Commit 33993069 by 前钰

Upload New File

parent d22e6233
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

import wavelet  # wavelet.py must provide create_wavelet_filter, wavelet_transform, inverse_wavelet_transform
class WTConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):
        super(WTConv2d, self).__init__()

        assert in_channels == out_channels  # input and output channel counts must match so the wavelet decomposition/reconstruction leaves the channel count unchanged

        self.in_channels = in_channels  # number of input channels
        self.wt_levels = wt_levels      # number of wavelet decomposition levels
        self.stride = stride            # stride (for the optional downsampling)
        self.dilation = 1               # default dilation of 1 (i.e. an ordinary convolution)

        # Create the wavelet (wt_filter) and inverse-wavelet (iwt_filter) filter banks
        self.wt_filter, self.iwt_filter = wavelet.create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
        self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)  # fixed parameter, not trained
        self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)

        # Partially apply the transform functions so later calls only pass the input tensor
        self.wt_function = partial(wavelet.wavelet_transform, filters=self.wt_filter)
        self.iwt_function = partial(wavelet.inverse_wavelet_transform, filters=self.iwt_filter)

        # Base convolution: depthwise (groups=in_channels), each channel convolved independently
        self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1, groups=in_channels, bias=bias)
        self.base_scale = _ScaleModule([1, in_channels, 1, 1])  # learnable scaling of the base-conv output

        # One depthwise convolution per wavelet level, acting on 4x the channels (LL, LH, HL, HH)
        self.wavelet_convs = nn.ModuleList([
            nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1, groups=in_channels * 4, bias=False)
            for _ in range(self.wt_levels)
        ])
        # Matching scale module per level (controls how strongly the wavelet branch contributes)
        self.wavelet_scale = nn.ModuleList([
            _ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)
        ])

        # Optional downsampling when stride > 1, implemented as a strided depthwise 1x1 convolution
        if self.stride > 1:
            self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)  # depthwise identity filter
            self.do_stride = lambda x_in: F.conv2d(x_in, self.stride_filter, bias=None, stride=self.stride, groups=in_channels)
        else:
            self.do_stride = None  # no downsampling needed
    def forward(self, x):
        x_ll_in_levels = []    # low-frequency (LL) component produced at each level
        x_h_in_levels = []     # high-frequency components (LH, HL, HH) produced at each level
        shapes_in_levels = []  # original input shape at each level, needed for reconstruction

        curr_x_ll = x  # the initial low-frequency component is the input itself

        # Level-by-level wavelet decomposition
        for i in range(self.wt_levels):
            curr_shape = curr_x_ll.shape         # current resolution
            shapes_in_levels.append(curr_shape)  # saved so the result can be cropped back later

            # If height or width is odd, pad it up to an even size
            if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
                curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
                curr_x_ll = F.pad(curr_x_ll, curr_pads)

            # Wavelet transform
            curr_x = self.wt_function(curr_x_ll)  # output: [B, C, 4, H/2, W/2]; the 4 sub-bands are LL, LH, HL, HH
            curr_x_ll = curr_x[:, :, 0, :, :]     # keep only the LL band as input to the next level

            # Flatten the sub-band axis so one 2D depthwise convolution can process all sub-bands
            shape_x = curr_x.shape  # [B, C, 4, H/2, W/2]
            curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
            curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))  # convolution + scaling
            curr_x_tag = curr_x_tag.reshape(shape_x)  # reshape back to the original shape

            # Split into LL and the remaining high-frequency bands (LH, HL, HH)
            x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])   # store LL
            x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])  # store the other three bands

        next_x_ll = 0  # LL reconstructed from the level below, initially zero

        # Level-by-level inverse wavelet reconstruction (deepest level first)
        for i in range(self.wt_levels - 1, -1, -1):
            curr_x_ll = x_ll_in_levels.pop()     # LL at this level
            curr_x_h = x_h_in_levels.pop()       # high-frequency bands at this level
            curr_shape = shapes_in_levels.pop()  # input shape at this level

            curr_x_ll = curr_x_ll + next_x_ll  # add the LL reconstructed from the level below

            # Concatenate LL and the high-frequency bands for reconstruction
            curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)  # [B, C, 4, H, W]
            next_x_ll = self.iwt_function(curr_x)  # inverse wavelet transform

            # Crop back to the size before padding
            next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]

        x_tag = next_x_ll  # output of the wavelet branch
        assert len(x_ll_in_levels) == 0  # sanity check: the stack must be empty

        # Base depthwise convolution on the original input, followed by scaling
        x = self.base_scale(self.base_conv(x))
        x = x + x_tag  # fuse with the wavelet-branch output

        if self.do_stride is not None:
            x = self.do_stride(x)  # optional downsampling

        return x  # fused result
class _ScaleModule(nn.Module):
    def __init__(self, dims, init_scale=1.0, init_bias=0):
        super(_ScaleModule, self).__init__()
        self.dims = dims
        self.weight = nn.Parameter(torch.ones(*dims) * init_scale)  # learnable scale factor, initialized to init_scale
        self.bias = None  # optional bias term, unused here

    def forward(self, x):
        return torch.mul(self.weight, x)  # per-channel scaling
# Usage example
in_channels = 8
out_channels = 8  # must equal in_channels

input_tensor = torch.randn(1, in_channels, 64, 64)  # input tensor: batch=1, 8 channels, 64x64 image

model = WTConv2d(
    in_channels=in_channels,    # number of channels in the input feature map (e.g. 3 for an RGB image)
    out_channels=out_channels,  # number of output channels (must match in_channels)
    kernel_size=3,              # kernel size of the convolutions, e.g. 3x3
    stride=1,                   # stride, controls the optional downsampling step
    wt_levels=2,                # number of wavelet decomposition levels, e.g. 2 for a two-level decomposition
    wt_type='db1'               # wavelet type, e.g. 'db1' for the Daubechies-1 (Haar) wavelet
)

output_tensor = model(input_tensor)  # forward pass
print("Input tensor shape:", input_tensor.shape)
print("Output tensor shape:", output_tensor.shape)