{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "device-intro",
   "metadata": {},
   "source": [
    "# Device placement and `nn.DataParallel` in PyTorch\n",
    "\n",
    "This notebook does three things:\n",
    "\n",
    "- Picks a CUDA device when one is available, falling back to the CPU otherwise.\n",
    "- Moves a tensor and a small model onto that device.\n",
    "- Runs one forward pass through `nn.DataParallel`.\n",
    "\n",
    "The `print` inside `ASimpleNet.forward` shows the batch size each replica receives. With zero or one GPU, `DataParallel` runs the wrapped module on the whole batch, so the full batch size is printed once."
   ]
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {},
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Fall back to the CPU so every cell below also runs without a GPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "tensors-md",
   "metadata": {},
   "source": [
    "## Tensors\n",
    "\n",
    "`torch.ones` creates the tensor on the default device, which is the CPU unless configured otherwise. `Tensor.to(device)` returns a tensor on the target device and leaves `data` where it is."
   ]
  },
  {
   "cell_type": "code",
   "id": "7a630763614905d",
   "metadata": {},
   "source": [
    "data = torch.ones(3, 3)\n",
    "data.device"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "e2a2d8a6d60231c",
   "metadata": {},
   "source": [
    "# `device` comes from the setup cell, so this works with or without CUDA.\n",
    "data_on_device = data.to(device)\n",
    "data_on_device.device"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "modules-md",
   "metadata": {},
   "source": [
    "## Modules\n",
    "\n",
    "`nn.Module.to(device)` moves the module's parameters and buffers in place and returns the module itself, so the call can be chained."
   ]
  },
  {
   "cell_type": "code",
   "id": "458ea27224fd0061",
   "metadata": {},
   "source": [
    "seq_net = nn.Sequential(nn.Linear(3, 3)).to(device)\n",
    "seq_net"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "dataparallel-md",
   "metadata": {},
   "source": [
    "## `nn.DataParallel`\n",
    "\n",
    "`ASimpleNet` applies `layers` bias-free `Linear(3, 3)` layers in sequence, then a single ReLU. `nn.ModuleList` only registers the layers as submodules and has no `forward` of its own, so `forward` iterates over it explicitly."
   ]
  },
  {
   "cell_type": "code",
   "id": "4859aa95dd22d01d",
   "metadata": {},
   "source": [
    "class ASimpleNet(nn.Module):\n",
    "    \"\"\"Stack of bias-free 3 -> 3 linear layers followed by a single ReLU.\n",
    "\n",
    "    Args:\n",
    "        layers: number of ``nn.Linear(3, 3, bias=False)`` layers applied in sequence.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, layers=3):\n",
    "        super(ASimpleNet, self).__init__()\n",
    "        self.linears = nn.ModuleList([nn.Linear(3, 3, bias=False) for _ in range(layers)])\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Under nn.DataParallel each replica receives only its slice of the batch,\n",
    "        # so this shows the per-replica batch size.\n",
    "        print(\"forward batchsize is: {}\".format(x.size()[0]))\n",
    "        # nn.ModuleList has no forward(); apply the layers one after another.\n",
    "        for linear in self.linears:\n",
    "            x = linear(x)\n",
    "        return torch.relu(x)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "d3eeb897f7f0ee68",
   "metadata": {},
   "source": [
    "BATCH_SIZE = 16\n",
    "inputs = torch.randn(BATCH_SIZE, 3).to(device)\n",
    "\n",
    "# DataParallel expects the wrapped module's parameters on its first device before the forward pass.\n",
    "parallel_net = nn.DataParallel(ASimpleNet()).to(device)\n",
    "print(\"visible CUDA devices: {}\".format(torch.cuda.device_count()))\n",
    "\n",
    "outputs = parallel_net(inputs)\n",
    "assert outputs.shape == (BATCH_SIZE, 3)\n",
    "outputs.shape"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}