Files
pytorch-study/10.ipynb

2.5 KiB
Raw Permalink Blame History

In [1]:
import torch
import torch.nn as nn
In [3]:
# Random input batch in (N, C, H, W) layout: one sample, 3 channels, 5x5 spatial.
x = torch.randn(1, 3, 5, 5)
print(x.shape)
torch.Size([1, 3, 5, 5])
In [4]:
# Depthwise (DW) convolution: note that the number of output channels equals
# the number of input channels (one filter per input channel).
# x.shape[1] is the channel dimension of the (N, C, H, W) input built above.
in_channels_dw = out_channels_dw = x.shape[1]
# A 3x3 kernel is the usual choice for depthwise convolutions.
kernel_size_dw, stride_dw = 3, 1

# groups == in_channels puts every input channel in its own group, which is
# what makes this convolution depthwise. No padding is given (Conv2d defaults
# to padding=0), so each spatial dimension shrinks by kernel_size_dw - 1.
dw = nn.Conv2d(
    in_channels=in_channels_dw,
    out_channels=out_channels_dw,
    kernel_size=kernel_size_dw,
    stride=stride_dw,
    groups=in_channels_dw,
)
In [5]:
# Pointwise (PW) convolution: a 1x1 conv that mixes information across
# channels and sets the final channel count. It consumes the depthwise
# output, so its input channel count is out_channels_dw.
in_channels_pw = out_channels_dw
out_channels_pw = 4
kernel_size_pw = 1

pw = nn.Conv2d(
    in_channels=in_channels_pw,
    out_channels=out_channels_pw,
    kernel_size=kernel_size_pw,
    stride=1,
)

# Depthwise followed by pointwise = a depthwise-separable convolution.
out_dw = dw(x)
out = pw(out_dw)
print(out.shape)
torch.Size([1, 4, 3, 3])