Commit 74e84cec by 前钰

Upload New File

parent a6270e6d
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
class DepthwiseSeparableConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
super(DepthwiseSeparableConv, self).__init__()
# Step 1: Depthwise convolution(每个输入通道单独卷积)
self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
stride=stride, padding=padding, groups=in_channels, bias=False)
# Step 2: Pointwise convolution(1x1卷积,用于通道之间的信息融合)
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1,
stride=1, padding=0, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
out = self.depthwise(x) # [B, C, H, W]
out = self.pointwise(out) # [B, C_out, H, W]
out = self.bn(out)
out = self.relu(out)
return out
model = DepthwiseSeparableConv(in_channels=32, out_channels=32)
input_tensor = torch.randn(1, 32, 128, 128) # batch size 1, 32通道,128x128图像
output = model(input_tensor)
print(output.shape) # 输出: torch.Size([1, 64, 128, 128])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment