feat: refactor Jupyter notebook for convolution layer demonstration with updated input tensor and fixed kernel parameters
This commit is contained in:
119
10.ipynb
Normal file
119
10.ipynb
Normal file
@ -0,0 +1,119 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"id": "initial_id",
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-06-17T02:29:56.329629Z",
|
||||
"start_time": "2025-06-17T02:29:53.851211Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 1
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-06-17T02:30:23.709542Z",
|
||||
"start_time": "2025-06-17T02:30:23.705090Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"x = torch.randn(3, 5, 5).unsqueeze(0)\n",
|
||||
"print(x.shape)"
|
||||
],
|
||||
"id": "19a429395361b901",
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([1, 3, 5, 5])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 3
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-06-17T02:32:50.342968Z",
|
||||
"start_time": "2025-06-17T02:32:50.335771Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# 请注意DW中,输入特征通道数与输出通道数是一样的\n",
|
||||
"in_channels_dw = x.shape[1]\n",
|
||||
"out_channels_dw = x.shape[1]\n",
|
||||
"# 一般来讲DW卷积的kernel size 为3\n",
|
||||
"kernel_size_dw = 3\n",
|
||||
"stride_dw = 1\n",
|
||||
"\n",
|
||||
"# DW 卷积groups参数与输入通道数一样\n",
|
||||
"dw = nn.Conv2d(in_channels_dw, out_channels_dw, kernel_size_dw, stride=stride_dw, groups=in_channels_dw)"
|
||||
],
|
||||
"id": "23ca0000610f16a0",
|
||||
"outputs": [],
|
||||
"execution_count": 4
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-06-17T02:34:22.427522Z",
|
||||
"start_time": "2025-06-17T02:34:22.369462Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"in_channels_pw = out_channels_dw\n",
|
||||
"out_channels_pw = 4\n",
|
||||
"kernel_size_pw = 1\n",
|
||||
"\n",
|
||||
"pw = nn.Conv2d(in_channels_pw, out_channels_pw, kernel_size_pw, stride=1)\n",
|
||||
"\n",
|
||||
"out = pw(dw(x))\n",
|
||||
"print(out.shape)"
|
||||
],
|
||||
"id": "d0bb1304d1f98d2e",
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch.Size([1, 4, 3, 3])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 5
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user