2.5 KiB
2.5 KiB
In [1]:
import torch import torch.nn as nn
In [3]:
# Seed torch's global RNG so the random input below — and anything drawn from
# the global RNG afterwards, such as default layer weight initialisation — is
# reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# One random 3-channel 5x5 input laid out as (N, C, H, W) with batch size 1,
# the layout nn.Conv2d expects.
x = torch.randn(1, 3, 5, 5)
x.shape
torch.Size([1, 3, 5, 5])
In [4]:
# Depthwise (DW) convolution: each input channel is convolved with its own
# spatial filter. This example uses a channel multiplier of 1, so the number
# of output channels equals the number of input channels (x.shape[1]).
in_channels_dw = out_channels_dw = x.shape[1]

# A 3x3 kernel is the usual choice for DW convolutions. No padding is passed
# (nn.Conv2d defaults to padding=0), so each spatial dimension shrinks by
# kernel_size - 1.
kernel_size_dw, stride_dw = 3, 1

# Setting groups equal to the number of input channels is what makes the
# convolution depthwise: each group sees exactly one input channel.
dw = nn.Conv2d(
    in_channels=in_channels_dw,
    out_channels=out_channels_dw,
    kernel_size=kernel_size_dw,
    stride=stride_dw,
    groups=in_channels_dw,
)
In [5]:
# Pointwise (PW) convolution: a 1x1 conv that mixes information across
# channels, mapping the DW output channels to out_channels_pw channels.
in_channels_pw = out_channels_dw
out_channels_pw = 4
kernel_size_pw = 1
pw = nn.Conv2d(
    in_channels=in_channels_pw,
    out_channels=out_channels_pw,
    kernel_size=kernel_size_pw,
    stride=1,
)

# Depthwise-separable convolution: DW first, then PW.
out = pw(dw(x))
print(out.shape)
torch.Size([1, 4, 3, 3])