refactor: update Jupyter notebook and main script for improved tensor operations and output display

This commit is contained in:
fada
2025-06-12 11:46:42 +08:00
parent 8bf116cbde
commit 97637b29a2
2 changed files with 211 additions and 18 deletions

227
04.ipynb
View File

@ -6,8 +6,8 @@
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-06-11T15:26:24.930498Z",
"start_time": "2025-06-11T15:26:24.925343Z"
"end_time": "2025-06-12T02:34:22.530839Z",
"start_time": "2025-06-12T02:34:20.159404Z"
}
},
"source": [
@ -23,18 +23,18 @@
"'2.2.1'"
]
},
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 4
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-11T15:30:24.600025Z",
"start_time": "2025-06-11T15:30:24.594879Z"
"end_time": "2025-06-12T02:34:24.139479Z",
"start_time": "2025-06-12T02:34:24.116052Z"
}
},
"cell_type": "code",
@ -55,13 +55,13 @@
]
}
],
"execution_count": 11
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-11T15:30:58.264992Z",
"start_time": "2025-06-11T15:30:58.260725Z"
"end_time": "2025-06-12T02:34:26.619937Z",
"start_time": "2025-06-12T02:34:26.608636Z"
}
},
"cell_type": "code",
@ -81,13 +81,13 @@
]
}
],
"execution_count": 13
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-11T15:33:27.945096Z",
"start_time": "2025-06-11T15:33:27.939574Z"
"end_time": "2025-06-12T02:34:28.340829Z",
"start_time": "2025-06-12T02:34:28.333276Z"
}
},
"cell_type": "code",
@ -111,18 +111,211 @@
]
}
],
"execution_count": 18
"execution_count": 4
},
{
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:34:37.031549Z",
"start_time": "2025-06-12T02:34:36.991394Z"
}
},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"x = torch.rand(2, 3, 5)\n",
"print(x.shape)\n",
"print(x)"
],
"id": "774add1439f9aa94",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 3, 5])\n",
"tensor([[[0.1437, 0.3582, 0.4219, 0.4514, 0.6537],\n",
" [0.0089, 0.5737, 0.0201, 0.7728, 0.1827],\n",
" [0.6573, 0.1262, 0.0877, 0.2302, 0.0151]],\n",
"\n",
" [[0.0757, 0.7126, 0.4238, 0.0535, 0.0578],\n",
" [0.4909, 0.5616, 0.7342, 0.7925, 0.8879],\n",
" [0.3011, 0.1606, 0.2856, 0.8165, 0.4100]]])\n"
]
}
],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:34:48.128975Z",
"start_time": "2025-06-12T02:34:48.116080Z"
}
},
"cell_type": "code",
"source": [
"# 矩阵转秩\n",
"x = x.permute(2, 1, 0)\n",
"print(x.shape)\n",
"print(x)"
],
"id": "ceb1debce1c62ffd",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([5, 3, 2])\n",
"tensor([[[0.1437, 0.0757],\n",
" [0.0089, 0.4909],\n",
" [0.6573, 0.3011]],\n",
"\n",
" [[0.3582, 0.7126],\n",
" [0.5737, 0.5616],\n",
" [0.1262, 0.1606]],\n",
"\n",
" [[0.4219, 0.4238],\n",
" [0.0201, 0.7342],\n",
" [0.0877, 0.2856]],\n",
"\n",
" [[0.4514, 0.0535],\n",
" [0.7728, 0.7925],\n",
" [0.2302, 0.8165]],\n",
"\n",
" [[0.6537, 0.0578],\n",
" [0.1827, 0.8879],\n",
" [0.0151, 0.4100]]])\n"
]
}
],
"execution_count": 7
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:36:03.393949Z",
"start_time": "2025-06-12T02:36:03.381874Z"
}
},
"cell_type": "code",
"source": [
"x = torch.rand(2, 3, 4)\n",
"x = x.transpose(1, 0)\n",
"print(x.shape)"
],
"id": "774add1439f9aa94"
"id": "e56528fc1753cf04",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([3, 2, 4])\n"
]
}
],
"execution_count": 8
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T02:49:57.093978Z",
"start_time": "2025-06-12T02:49:57.088700Z"
}
},
"cell_type": "code",
"source": [
"x = torch.rand(4, 4)\n",
"x = x.view(2, 8)\n",
"x = x.permute(1, 0)\n",
"# x.view(4,4) # 不能直接用view因为view需要连续的内存\n",
"x.reshape(4, 4)"
],
"id": "74df80d1396ec1d3",
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 8])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T03:34:30.576671Z",
"start_time": "2025-06-12T03:34:30.569216Z"
}
},
"cell_type": "code",
"source": [
"# 增减维度\n",
"x = torch.rand(2, 1, 3)\n",
"print(x)\n",
"x = x.squeeze(1) # 去掉维度为1的维度\n",
"\n",
"print(x.shape)\n",
"print(x)\n",
"\n"
],
"id": "f5009aa3b8b1335c",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[0.0287, 0.7995, 0.4072]],\n",
"\n",
" [[0.4378, 0.6384, 0.2777]]])\n",
"torch.Size([2, 3])\n",
"tensor([[0.0287, 0.7995, 0.4072],\n",
" [0.4378, 0.6384, 0.2777]])\n"
]
}
],
"execution_count": 29
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T03:42:20.284801Z",
"start_time": "2025-06-12T03:42:20.271042Z"
}
},
"cell_type": "code",
"source": [
"# 增减维度\n",
"x = torch.rand(2, 1, 3)\n",
"print(x)\n",
"x = x.unsqueeze() # 去掉维度为1的维度\n",
"\n",
"print(x.shape)\n",
"print(x)\n",
"\n"
],
"id": "dc138eb85bed2f3e",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[0.4243, 0.1581, 0.4620]],\n",
"\n",
" [[0.8510, 0.5490, 0.7694]]])\n",
"torch.Size([2, 1, 1, 3])\n",
"tensor([[[[0.4243, 0.1581, 0.4620]]],\n",
"\n",
"\n",
" [[[0.8510, 0.5490, 0.7694]]]])\n"
]
}
],
"execution_count": 30
}
],
"metadata": {

View File

@ -1,4 +1,4 @@
import torch
import numpy as np
torch.__version__
print(torch.__version__)