Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
人
人工智能系统实战第三期
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Charles
人工智能系统实战第三期
Commits
1ce6841d
Commit
1ce6841d
authored
Dec 24, 2023
by
前钰
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Upload New File
parent
2c80ef8e
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
107 additions
and
0 deletions
+107
-0
utils.py
人工智能系统实战第三期/实战代码/深度学习项目实战/扩散模型作业/DDPM/utils/utils.py
+107
-0
No files found.
人工智能系统实战第三期/实战代码/深度学习项目实战/扩散模型作业/DDPM/utils/utils.py
0 → 100644
View file @
1ce6841d
import
itertools
import
itertools
import
math
from
functools
import
partial
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
torch
#---------------------------------------------------------#
# 将图像转换成RGB图像,防止灰度图在预测时报错。
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
#---------------------------------------------------------#
def
cvtColor
(
image
):
if
len
(
np
.
shape
(
image
))
==
3
and
np
.
shape
(
image
)[
2
]
==
3
:
return
image
else
:
image
=
image
.
convert
(
'RGB'
)
return
image
#----------------------------------------#
# 预处理训练图片
#----------------------------------------#
def
preprocess_input
(
x
):
x
/=
255
x
-=
0.5
x
/=
0.5
return
x
def
postprocess_output
(
x
):
x
*=
0.5
x
+=
0.5
x
*=
255
return
x
def
show_result
(
num_epoch
,
net
,
device
):
test_images
=
net
.
sample
(
4
,
device
)
size_figure_grid
=
2
fig
,
ax
=
plt
.
subplots
(
size_figure_grid
,
size_figure_grid
,
figsize
=
(
5
,
5
))
for
i
,
j
in
itertools
.
product
(
range
(
size_figure_grid
),
range
(
size_figure_grid
)):
ax
[
i
,
j
]
.
get_xaxis
()
.
set_visible
(
False
)
ax
[
i
,
j
]
.
get_yaxis
()
.
set_visible
(
False
)
for
k
in
range
(
2
*
2
):
i
=
k
//
2
j
=
k
%
2
ax
[
i
,
j
]
.
cla
()
ax
[
i
,
j
]
.
imshow
(
np
.
uint8
(
postprocess_output
(
test_images
[
k
]
.
cpu
()
.
data
.
numpy
()
.
transpose
(
1
,
2
,
0
))))
label
=
'Epoch {0}'
.
format
(
num_epoch
)
fig
.
text
(
0.5
,
0.04
,
label
,
ha
=
'center'
)
plt
.
savefig
(
"results/train_out/epoch_"
+
str
(
num_epoch
)
+
"_results.png"
)
plt
.
close
(
'all'
)
#避免内存泄漏
def
show_config
(
**
kwargs
):
print
(
'Configurations:'
)
print
(
'-'
*
70
)
print
(
'|
%25
s |
%40
s|'
%
(
'keys'
,
'values'
))
print
(
'-'
*
70
)
for
key
,
value
in
kwargs
.
items
():
print
(
'|
%25
s |
%40
s|'
%
(
str
(
key
),
str
(
value
)))
print
(
'-'
*
70
)
#---------------------------------------------------#
# 获得学习率
#---------------------------------------------------#
def
get_lr
(
optimizer
):
for
param_group
in
optimizer
.
param_groups
:
return
param_group
[
'lr'
]
def
get_lr_scheduler
(
lr_decay_type
,
lr
,
min_lr
,
total_iters
,
warmup_iters_ratio
=
0.05
,
warmup_lr_ratio
=
0.1
,
no_aug_iter_ratio
=
0.05
,
step_num
=
10
):
def
yolox_warm_cos_lr
(
lr
,
min_lr
,
total_iters
,
warmup_total_iters
,
warmup_lr_start
,
no_aug_iter
,
iters
):
if
iters
<=
warmup_total_iters
:
# lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
lr
=
(
lr
-
warmup_lr_start
)
*
pow
(
iters
/
float
(
warmup_total_iters
),
2
)
+
warmup_lr_start
elif
iters
>=
total_iters
-
no_aug_iter
:
lr
=
min_lr
else
:
lr
=
min_lr
+
0.5
*
(
lr
-
min_lr
)
*
(
1.0
+
math
.
cos
(
math
.
pi
*
(
iters
-
warmup_total_iters
)
/
(
total_iters
-
warmup_total_iters
-
no_aug_iter
))
)
return
lr
def
step_lr
(
lr
,
decay_rate
,
step_size
,
iters
):
if
step_size
<
1
:
raise
ValueError
(
"step_size must above 1."
)
n
=
iters
//
step_size
out_lr
=
lr
*
decay_rate
**
n
return
out_lr
if
lr_decay_type
==
"cos"
:
warmup_total_iters
=
min
(
max
(
warmup_iters_ratio
*
total_iters
,
1
),
3
)
warmup_lr_start
=
max
(
warmup_lr_ratio
*
lr
,
1e-6
)
no_aug_iter
=
min
(
max
(
no_aug_iter_ratio
*
total_iters
,
1
),
15
)
func
=
partial
(
yolox_warm_cos_lr
,
lr
,
min_lr
,
total_iters
,
warmup_total_iters
,
warmup_lr_start
,
no_aug_iter
)
else
:
decay_rate
=
(
min_lr
/
lr
)
**
(
1
/
(
step_num
-
1
))
step_size
=
total_iters
/
step_num
func
=
partial
(
step_lr
,
lr
,
decay_rate
,
step_size
)
return
func
def
set_optimizer_lr
(
optimizer
,
lr_scheduler_func
,
epoch
):
lr
=
lr_scheduler_func
(
epoch
)
for
param_group
in
optimizer
.
param_groups
:
param_group
[
'lr'
]
=
lr
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment