feat: refactor Jupyter notebook for convolution layer demonstration with updated input tensor and fixed kernel parameters

This commit is contained in:
fada
2025-06-17 20:55:32 +08:00
parent c8b00eea2d
commit 30607ae7c1

119
10.ipynb Normal file
View File

@ -0,0 +1,119 @@
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-06-17T02:29:56.329629Z",
"start_time": "2025-06-17T02:29:53.851211Z"
}
},
"source": [
"import torch\n",
"import torch.nn as nn"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-17T02:30:23.709542Z",
"start_time": "2025-06-17T02:30:23.705090Z"
}
},
"cell_type": "code",
"source": [
"x = torch.randn(3, 5, 5).unsqueeze(0)\n",
"print(x.shape)"
],
"id": "19a429395361b901",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 3, 5, 5])\n"
]
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-17T02:32:50.342968Z",
"start_time": "2025-06-17T02:32:50.335771Z"
}
},
"cell_type": "code",
"source": [
"# 请注意DW中输入特征通道数与输出通道数是一样的\n",
"in_channels_dw = x.shape[1]\n",
"out_channels_dw = x.shape[1]\n",
"# 一般来讲DW卷积的kernel size 为3\n",
"kernel_size_dw = 3\n",
"stride_dw = 1\n",
"\n",
"# DW 卷积groups参数与输入通道数一样\n",
"dw = nn.Conv2d(in_channels_dw, out_channels_dw, kernel_size_dw, stride=stride_dw, groups=in_channels_dw)"
],
"id": "23ca0000610f16a0",
"outputs": [],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-17T02:34:22.427522Z",
"start_time": "2025-06-17T02:34:22.369462Z"
}
},
"cell_type": "code",
"source": [
"in_channels_pw = out_channels_dw\n",
"out_channels_pw = 4\n",
"kernel_size_pw = 1\n",
"\n",
"pw = nn.Conv2d(in_channels_pw, out_channels_pw, kernel_size_pw, stride=1)\n",
"\n",
"out = pw(dw(x))\n",
"print(out.shape)"
],
"id": "d0bb1304d1f98d2e",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 4, 3, 3])\n"
]
}
],
"execution_count": 5
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}