Files
pytorch-study/04.ipynb

343 lines
7.0 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-06-12T02:34:22.530839Z",
"start_time": "2025-06-12T02:34:20.159404Z"
}
},
"source": [
"# Imports for the whole notebook\n",
"import numpy as np\n",
"import torch\n",
"\n",
"torch.__version__  # last expression of the cell is displayed as its output"
],
"outputs": [
{
"data": {
"text/plain": [
"'2.2.1'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:34:24.139479Z",
"start_time": "2025-06-12T02:34:24.116052Z"
}
},
"cell_type": "code",
"source": [
"# .item() turns a single-element tensor into a plain Python number\n",
"a = torch.tensor(1)\n",
"b = a.item()\n",
"print(a, b, sep='\\n')"
],
"id": "ec73dc2f6feeece4",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(1)\n",
"1\n"
]
}
],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:34:26.619937Z",
"start_time": "2025-06-12T02:34:26.608636Z"
}
},
"cell_type": "code",
"source": [
"# Round trip: Python list -> tensor -> NumPy array -> Python list\n",
"a = [1, 2, 3]\n",
"b = torch.tensor(a)\n",
"c = b.numpy()\n",
"c = c.tolist()\n",
"print(c)"
],
"id": "6c3e0063d8fcc299",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 2, 3]\n"
]
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:34:28.340829Z",
"start_time": "2025-06-12T02:34:28.333276Z"
}
},
"cell_type": "code",
"source": [
"a = torch.zeros(2, 3, 5)\n",
"# .shape (attribute) and .size() (method) give the same torch.Size;\n",
"# .numel() is the total number of elements: 2 * 3 * 5 = 30\n",
"print(a.shape, a.size(), a.numel(), sep='\\n')"
],
"id": "d04c60d3f01351c2",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 3, 5])\n",
"torch.Size([2, 3, 5])\n",
"30\n"
]
}
],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:34:37.031549Z",
"start_time": "2025-06-12T02:34:36.991394Z"
}
},
"cell_type": "code",
"source": [
"# Uniform random values in [0, 1); no seed is set, so the values differ on every run\n",
"x = torch.rand((2, 3, 5))\n",
"print(x.size())\n",
"print(x)"
],
"id": "774add1439f9aa94",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 3, 5])\n",
"tensor([[[0.1437, 0.3582, 0.4219, 0.4514, 0.6537],\n",
" [0.0089, 0.5737, 0.0201, 0.7728, 0.1827],\n",
" [0.6573, 0.1262, 0.0877, 0.2302, 0.0151]],\n",
"\n",
" [[0.0757, 0.7126, 0.4238, 0.0535, 0.0578],\n",
" [0.4909, 0.5616, 0.7342, 0.7925, 0.8879],\n",
" [0.3011, 0.1606, 0.2856, 0.8165, 0.4100]]])\n"
]
}
],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:34:48.128975Z",
"start_time": "2025-06-12T02:34:48.116080Z"
}
},
"cell_type": "code",
"source": [
"# Permute dimensions (a generalised transpose): reverse the axis order, (2, 3, 5) -> (5, 3, 2)\n",
"# Bind the result to a new name so re-running this cell does not permute x a second time\n",
"x_perm = x.permute(2, 1, 0)\n",
"print(x_perm.shape)\n",
"print(x_perm)"
],
"id": "ceb1debce1c62ffd",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([5, 3, 2])\n",
"tensor([[[0.1437, 0.0757],\n",
" [0.0089, 0.4909],\n",
" [0.6573, 0.3011]],\n",
"\n",
" [[0.3582, 0.7126],\n",
" [0.5737, 0.5616],\n",
" [0.1262, 0.1606]],\n",
"\n",
" [[0.4219, 0.4238],\n",
" [0.0201, 0.7342],\n",
" [0.0877, 0.2856]],\n",
"\n",
" [[0.4514, 0.0535],\n",
" [0.7728, 0.7925],\n",
" [0.2302, 0.8165]],\n",
"\n",
" [[0.6537, 0.0578],\n",
" [0.1827, 0.8879],\n",
" [0.0151, 0.4100]]])\n"
]
}
],
"execution_count": 7
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:36:03.393949Z",
"start_time": "2025-06-12T02:36:03.381874Z"
}
},
"cell_type": "code",
"source": [
"# transpose swaps exactly two dimensions (order of the two indices does not matter): (2, 3, 4) -> (3, 2, 4)\n",
"x = torch.rand(2, 3, 4).transpose(0, 1)\n",
"print(x.shape)"
],
"id": "e56528fc1753cf04",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([3, 2, 4])\n"
]
}
],
"execution_count": 8
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:49:57.093978Z",
"start_time": "2025-06-12T02:49:57.088700Z"
}
},
"cell_type": "code",
"source": [
"x = torch.rand(4, 4)\n",
"# view(2, 8) on the freshly created tensor works; permute(1, 0) then yields an (8, 2) view that is not contiguous\n",
"x_t = x.view(2, 8).permute(1, 0)\n",
"print(x_t.is_contiguous())\n",
"\n",
"# view() needs a memory layout compatible with the new shape, so it fails here;\n",
"# reshape() returns a view when possible and a copy otherwise, so it succeeds\n",
"try:\n",
"    x_t.view(4, 4)\n",
"except RuntimeError as err:\n",
"    print('view(4, 4) failed:', type(err).__name__)\n",
"\n",
"x_t.reshape(4, 4).shape"
],
"id": "74df80d1396ec1d3",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n",
"view(4, 4) failed: RuntimeError\n"
]
},
{
"data": {
"text/plain": [
"torch.Size([4, 4])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T03:34:30.576671Z",
"start_time": "2025-06-12T03:34:30.569216Z"
}
},
"cell_type": "code",
"source": [
"# Add / remove dimensions\n",
"x = torch.rand(2, 1, 3)\n",
"print(x)\n",
"\n",
"# squeeze at index 1 drops that dimension because its size is 1: (2, 1, 3) -> (2, 3)\n",
"x = torch.squeeze(x, 1)\n",
"print(x.shape, x, sep='\\n')"
],
"id": "f5009aa3b8b1335c",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[0.0287, 0.7995, 0.4072]],\n",
"\n",
" [[0.4378, 0.6384, 0.2777]]])\n",
"torch.Size([2, 3])\n",
"tensor([[0.0287, 0.7995, 0.4072],\n",
" [0.4378, 0.6384, 0.2777]])\n"
]
}
],
"execution_count": 29
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T03:42:20.284801Z",
"start_time": "2025-06-12T03:42:20.271042Z"
}
},
"cell_type": "code",
"source": [
"# Add a dimension\n",
"x = torch.rand(2, 1, 3)\n",
"print(x)\n",
"# unsqueeze requires the position of the new axis; it inserts a size-1 dimension at index 1: (2, 1, 3) -> (2, 1, 1, 3)\n",
"x = x.unsqueeze(1)\n",
"\n",
"print(x.shape)\n",
"print(x)"
],
"id": "dc138eb85bed2f3e",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[0.4243, 0.1581, 0.4620]],\n",
"\n",
" [[0.8510, 0.5490, 0.7694]]])\n",
"torch.Size([2, 1, 1, 3])\n",
"tensor([[[[0.4243, 0.1581, 0.4620]]],\n",
"\n",
"\n",
" [[[0.8510, 0.5490, 0.7694]]]])\n"
]
}
],
"execution_count": 30
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}