Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
点
点头人工智能课程-v6.0-影像
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
靓靓
点头人工智能课程-v6.0-影像
Commits
33993069
Commit
33993069
authored
Aug 04, 2025
by
前钰
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Upload New File
parent
d22e6233
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
138 additions
and
0 deletions
+138
-0
wtconv2d.py
4-模型改进/4.3-特征增强(下)/Conv/wtconv2d.py
+138
-0
No files found.
4-模型改进/4.3-特征增强(下)/Conv/wtconv2d.py
0 → 100644
View file @
33993069
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
functools
import
partial
import
wavelet
# wavelet.py 中应包含 create_wavelet_filter、wavelet_transform、inverse_wavelet_transform
class
WTConv2d
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
=
5
,
stride
=
1
,
bias
=
True
,
wt_levels
=
1
,
wt_type
=
'db1'
):
super
(
WTConv2d
,
self
)
.
__init__
()
assert
in_channels
==
out_channels
# 强制要求输入和输出通道数一致,便于小波分解和重构时通道数不变
self
.
in_channels
=
in_channels
# 输入通道数
self
.
wt_levels
=
wt_levels
# 小波分解的层数(level)
self
.
stride
=
stride
# 步长(用于可选的下采样)
self
.
dilation
=
1
# 默认膨胀卷积系数为1(即正常卷积)
# 创建小波滤波器(wt_filter)和逆小波滤波器(iwt_filter)
self
.
wt_filter
,
self
.
iwt_filter
=
wavelet
.
create_wavelet_filter
(
wt_type
,
in_channels
,
in_channels
,
torch
.
float
)
self
.
wt_filter
=
nn
.
Parameter
(
self
.
wt_filter
,
requires_grad
=
False
)
# 固定参数,不参与训练
self
.
iwt_filter
=
nn
.
Parameter
(
self
.
iwt_filter
,
requires_grad
=
False
)
# 部分应用小波变换函数,用于后续传参简化
self
.
wt_function
=
partial
(
wavelet
.
wavelet_transform
,
filters
=
self
.
wt_filter
)
self
.
iwt_function
=
partial
(
wavelet
.
inverse_wavelet_transform
,
filters
=
self
.
iwt_filter
)
# 基础卷积层,分组卷积(groups=in_channels),每个通道独立卷积
self
.
base_conv
=
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
,
padding
=
'same'
,
stride
=
1
,
dilation
=
1
,
groups
=
in_channels
,
bias
=
bias
)
self
.
base_scale
=
_ScaleModule
([
1
,
in_channels
,
1
,
1
])
# 对基础卷积结果进行缩放(可学习权重)
# 为每个小波层级准备一组卷积操作(作用于4倍通道:LL, LH, HL, HH)
self
.
wavelet_convs
=
nn
.
ModuleList
([
nn
.
Conv2d
(
in_channels
*
4
,
in_channels
*
4
,
kernel_size
,
padding
=
'same'
,
stride
=
1
,
dilation
=
1
,
groups
=
in_channels
*
4
,
bias
=
False
)
for
_
in
range
(
self
.
wt_levels
)
])
# 每层对应的缩放模块(用于控制小波处理影响强度)
self
.
wavelet_scale
=
nn
.
ModuleList
([
_ScaleModule
([
1
,
in_channels
*
4
,
1
,
1
],
init_scale
=
0.1
)
for
_
in
range
(
self
.
wt_levels
)
])
# 如果 stride > 1,定义可选的下采样方式(深度可分离)
if
self
.
stride
>
1
:
self
.
stride_filter
=
nn
.
Parameter
(
torch
.
ones
(
in_channels
,
1
,
1
,
1
),
requires_grad
=
False
)
# 深度分离
self
.
do_stride
=
lambda
x_in
:
F
.
conv2d
(
x_in
,
self
.
stride_filter
,
bias
=
None
,
stride
=
self
.
stride
,
groups
=
in_channels
)
else
:
self
.
do_stride
=
None
# 无需下采样
def
forward
(
self
,
x
):
x_ll_in_levels
=
[]
# 存储每一层的小波分解得到的低频成分
x_h_in_levels
=
[]
# 存储每一层的小波分解得到的高频成分
shapes_in_levels
=
[]
# 存储每层原始输入大小,便于重建
curr_x_ll
=
x
# 初始低频成分设为输入
# 逐层小波分解
for
i
in
range
(
self
.
wt_levels
):
curr_shape
=
curr_x_ll
.
shape
# 获取当前分辨率
shapes_in_levels
.
append
(
curr_shape
)
# 保存形状用于后续还原
# 如果宽或高为奇数,进行 padding 补足为偶数
if
(
curr_shape
[
2
]
%
2
>
0
)
or
(
curr_shape
[
3
]
%
2
>
0
):
curr_pads
=
(
0
,
curr_shape
[
3
]
%
2
,
0
,
curr_shape
[
2
]
%
2
)
curr_x_ll
=
F
.
pad
(
curr_x_ll
,
curr_pads
)
# 小波变换
curr_x
=
self
.
wt_function
(
curr_x_ll
)
# 输出: [B, C, 4, H/2, W/2],4 表示 LL, LH, HL, HH
curr_x_ll
=
curr_x
[:,
:,
0
,
:,
:]
# 仅保留 LL 分量
# reshape 展平以便使用 2D 卷积处理所有子带
shape_x
=
curr_x
.
shape
# [B, C, 4, H/2, W/2]
curr_x_tag
=
curr_x
.
reshape
(
shape_x
[
0
],
shape_x
[
1
]
*
4
,
shape_x
[
3
],
shape_x
[
4
])
curr_x_tag
=
self
.
wavelet_scale
[
i
](
self
.
wavelet_convs
[
i
](
curr_x_tag
))
# 卷积 + 缩放
curr_x_tag
=
curr_x_tag
.
reshape
(
shape_x
)
# reshape 回原 shape
# 拆分成 LL 和其他高频分量(LH、HL、HH)
x_ll_in_levels
.
append
(
curr_x_tag
[:,
:,
0
,
:,
:])
# 保存 LL
x_h_in_levels
.
append
(
curr_x_tag
[:,
:,
1
:
4
,
:,
:])
# 保存其余三个分量
next_x_ll
=
0
# 初始化下一层 LL
# 逐层逆小波重构(从高层往低层)
for
i
in
range
(
self
.
wt_levels
-
1
,
-
1
,
-
1
):
curr_x_ll
=
x_ll_in_levels
.
pop
()
# 当前层 LL
curr_x_h
=
x_h_in_levels
.
pop
()
# 当前层 高频分量
curr_shape
=
shapes_in_levels
.
pop
()
# 当前层输入形状
curr_x_ll
=
curr_x_ll
+
next_x_ll
# 加上上一层重构的 LL 分量
# 拼接 LL 和高频子带,准备重构
curr_x
=
torch
.
cat
([
curr_x_ll
.
unsqueeze
(
2
),
curr_x_h
],
dim
=
2
)
# 拼接成 [B, C, 4, H, W]
next_x_ll
=
self
.
iwt_function
(
curr_x
)
# 逆小波还原
# 截断尺寸,恢复到之前 pad 前的尺寸
next_x_ll
=
next_x_ll
[:,
:,
:
curr_shape
[
2
],
:
curr_shape
[
3
]]
x_tag
=
next_x_ll
# 小波重构后的输出
assert
len
(
x_ll_in_levels
)
==
0
# 检查栈是否清空,避免错误
# 基础分组卷积处理输入 x,再通过 scale 缩放
x
=
self
.
base_scale
(
self
.
base_conv
(
x
))
x
=
x
+
x_tag
# 与小波重构结果相加,实现融合
if
self
.
do_stride
is
not
None
:
x
=
self
.
do_stride
(
x
)
# 可选的下采样
return
x
# 输出融合结果
class
_ScaleModule
(
nn
.
Module
):
def
__init__
(
self
,
dims
,
init_scale
=
1.0
,
init_bias
=
0
):
super
(
_ScaleModule
,
self
)
.
__init__
()
self
.
dims
=
dims
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
*
dims
)
*
init_scale
)
# 可学习缩放因子,初始为1.0或指定值
self
.
bias
=
None
# 可选偏置项,这里未用
def
forward
(
self
,
x
):
return
torch
.
mul
(
self
.
weight
,
x
)
# 每通道缩放
# 使用示例
in_channels
=
8
out_channels
=
8
# 必须与 in_channels 一致
input_tensor
=
torch
.
randn
(
1
,
in_channels
,
64
,
64
)
# 输入张量: batch=1, 8通道, 64x64 图像
model
=
WTConv2d
(
in_channels
=
in_channels
,
# 输入特征图的通道数(例如RGB图像就是3)
out_channels
=
out_channels
,
# 输出特征图的通道数(卷积后的输出通道)
kernel_size
=
3
,
# 卷积核大小,标准卷积部分用的,比如3x3
stride
=
1
,
# 卷积的步长(stride),控制滑动窗口的移动
wt_levels
=
2
,
# 小波分解的层数(level),比如2表示进行两层小波分解
wt_type
=
'db1'
# 小波类型,比如 'db1' 表示使用 Daubechies 1 小波
)
output_tensor
=
model
(
input_tensor
)
# 推理
print
(
"输入张量形状:"
,
input_tensor
.
shape
)
print
(
"输出张量形状:"
,
output_tensor
.
shape
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment