refactor: update Jupyter notebook and main script for improved tensor operations and output display

This commit is contained in:
fada
2025-06-12 11:46:42 +08:00
parent 8bf116cbde
commit 97637b29a2
2 changed files with 211 additions and 18 deletions

227
04.ipynb
View File

@ -6,8 +6,8 @@
"metadata": { "metadata": {
"collapsed": true, "collapsed": true,
"ExecuteTime": { "ExecuteTime": {
"end_time": "2025-06-11T15:26:24.930498Z", "end_time": "2025-06-12T02:34:22.530839Z",
"start_time": "2025-06-11T15:26:24.925343Z" "start_time": "2025-06-12T02:34:20.159404Z"
} }
}, },
"source": [ "source": [
@ -23,18 +23,18 @@
"'2.2.1'" "'2.2.1'"
] ]
}, },
"execution_count": 4, "execution_count": 1,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"execution_count": 4 "execution_count": 1
}, },
{ {
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2025-06-11T15:30:24.600025Z", "end_time": "2025-06-12T02:34:24.139479Z",
"start_time": "2025-06-11T15:30:24.594879Z" "start_time": "2025-06-12T02:34:24.116052Z"
} }
}, },
"cell_type": "code", "cell_type": "code",
@ -55,13 +55,13 @@
] ]
} }
], ],
"execution_count": 11 "execution_count": 2
}, },
{ {
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2025-06-11T15:30:58.264992Z", "end_time": "2025-06-12T02:34:26.619937Z",
"start_time": "2025-06-11T15:30:58.260725Z" "start_time": "2025-06-12T02:34:26.608636Z"
} }
}, },
"cell_type": "code", "cell_type": "code",
@ -81,13 +81,13 @@
] ]
} }
], ],
"execution_count": 13 "execution_count": 3
}, },
{ {
"metadata": { "metadata": {
"ExecuteTime": { "ExecuteTime": {
"end_time": "2025-06-11T15:33:27.945096Z", "end_time": "2025-06-12T02:34:28.340829Z",
"start_time": "2025-06-11T15:33:27.939574Z" "start_time": "2025-06-12T02:34:28.333276Z"
} }
}, },
"cell_type": "code", "cell_type": "code",
@ -111,18 +111,211 @@
] ]
} }
], ],
"execution_count": 18 "execution_count": 4
}, },
{ {
"metadata": {}, "metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:34:37.031549Z",
"start_time": "2025-06-12T02:34:36.991394Z"
}
},
"cell_type": "code", "cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [ "source": [
"x = torch.rand(2, 3, 5)\n", "x = torch.rand(2, 3, 5)\n",
"print(x.shape)\n",
"print(x)"
],
"id": "774add1439f9aa94",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 3, 5])\n",
"tensor([[[0.1437, 0.3582, 0.4219, 0.4514, 0.6537],\n",
" [0.0089, 0.5737, 0.0201, 0.7728, 0.1827],\n",
" [0.6573, 0.1262, 0.0877, 0.2302, 0.0151]],\n",
"\n",
" [[0.0757, 0.7126, 0.4238, 0.0535, 0.0578],\n",
" [0.4909, 0.5616, 0.7342, 0.7925, 0.8879],\n",
" [0.3011, 0.1606, 0.2856, 0.8165, 0.4100]]])\n"
]
}
],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:34:48.128975Z",
"start_time": "2025-06-12T02:34:48.116080Z"
}
},
"cell_type": "code",
"source": [
"# 矩阵转秩\n",
"x = x.permute(2, 1, 0)\n",
"print(x.shape)\n",
"print(x)"
],
"id": "ceb1debce1c62ffd",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([5, 3, 2])\n",
"tensor([[[0.1437, 0.0757],\n",
" [0.0089, 0.4909],\n",
" [0.6573, 0.3011]],\n",
"\n",
" [[0.3582, 0.7126],\n",
" [0.5737, 0.5616],\n",
" [0.1262, 0.1606]],\n",
"\n",
" [[0.4219, 0.4238],\n",
" [0.0201, 0.7342],\n",
" [0.0877, 0.2856]],\n",
"\n",
" [[0.4514, 0.0535],\n",
" [0.7728, 0.7925],\n",
" [0.2302, 0.8165]],\n",
"\n",
" [[0.6537, 0.0578],\n",
" [0.1827, 0.8879],\n",
" [0.0151, 0.4100]]])\n"
]
}
],
"execution_count": 7
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:36:03.393949Z",
"start_time": "2025-06-12T02:36:03.381874Z"
}
},
"cell_type": "code",
"source": [
"x = torch.rand(2, 3, 4)\n",
"x = x.transpose(1, 0)\n",
"print(x.shape)" "print(x.shape)"
], ],
"id": "774add1439f9aa94" "id": "e56528fc1753cf04",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([3, 2, 4])\n"
]
}
],
"execution_count": 8
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:49:57.093978Z",
"start_time": "2025-06-12T02:49:57.088700Z"
}
},
"cell_type": "code",
"source": [
"x = torch.rand(4, 4)\n",
"x = x.view(2, 8)\n",
"x = x.permute(1, 0)\n",
"# x.view(4,4) # 不能直接用view因为view需要连续的内存\n",
"x.reshape(4, 4)"
],
"id": "74df80d1396ec1d3",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 8])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T03:34:30.576671Z",
"start_time": "2025-06-12T03:34:30.569216Z"
}
},
"cell_type": "code",
"source": [
"# 增减维度\n",
"x = torch.rand(2, 1, 3)\n",
"print(x)\n",
"x = x.squeeze(1) # 去掉维度为1的维度\n",
"\n",
"print(x.shape)\n",
"print(x)\n",
"\n"
],
"id": "f5009aa3b8b1335c",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[0.0287, 0.7995, 0.4072]],\n",
"\n",
" [[0.4378, 0.6384, 0.2777]]])\n",
"torch.Size([2, 3])\n",
"tensor([[0.0287, 0.7995, 0.4072],\n",
" [0.4378, 0.6384, 0.2777]])\n"
]
}
],
"execution_count": 29
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T03:42:20.284801Z",
"start_time": "2025-06-12T03:42:20.271042Z"
}
},
"cell_type": "code",
"source": [
"# 增减维度\n",
"x = torch.rand(2, 1, 3)\n",
"print(x)\n",
"x = x.unsqueeze() # 去掉维度为1的维度\n",
"\n",
"print(x.shape)\n",
"print(x)\n",
"\n"
],
"id": "dc138eb85bed2f3e",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[0.4243, 0.1581, 0.4620]],\n",
"\n",
" [[0.8510, 0.5490, 0.7694]]])\n",
"torch.Size([2, 1, 1, 3])\n",
"tensor([[[[0.4243, 0.1581, 0.4620]]],\n",
"\n",
"\n",
" [[[0.8510, 0.5490, 0.7694]]]])\n"
]
}
],
"execution_count": 30
} }
], ],
"metadata": { "metadata": {

View File

@ -1,4 +1,4 @@
import torch import torch
import numpy as np import numpy as np
torch.__version__ print(torch.__version__)