Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
人
人工智能系统实战第三期
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ding
人工智能系统实战第三期
Commits
0f0d0878
Commit
0f0d0878
authored
Mar 23, 2024
by
前钰
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Upload New File
parent
15e6fd24
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
120 additions
and
0 deletions
+120
-0
utils.py
人工智能系统实战第三期/实战代码/计算机视觉/diffusion变体/utils.py
+120
-0
No files found.
人工智能系统实战第三期/实战代码/计算机视觉/diffusion变体/utils.py
0 → 100644
View file @
0f0d0878
import
random
import
random
import
time
import
datetime
import
sys
from
torch.autograd
import
Variable
import
torch
from
visdom
import
Visdom
import
numpy
as
np
def
tensor2image
(
tensor
):
image
=
127.5
*
(
tensor
[
0
]
.
cpu
()
.
float
()
.
numpy
()
+
1.0
)
if
image
.
shape
[
0
]
==
1
:
image
=
np
.
tile
(
image
,
(
3
,
1
,
1
))
return
image
.
astype
(
np
.
uint8
)
class
Logger
():
def
__init__
(
self
,
n_epochs
,
batches_epoch
):
self
.
viz
=
Visdom
()
self
.
n_epochs
=
n_epochs
self
.
batches_epoch
=
batches_epoch
self
.
epoch
=
1
self
.
batch
=
1
self
.
prev_time
=
time
.
time
()
self
.
mean_period
=
0
self
.
losses
=
{}
self
.
loss_windows
=
{}
self
.
image_windows
=
{}
def
log
(
self
,
losses
=
None
,
images
=
None
):
self
.
mean_period
+=
(
time
.
time
()
-
self
.
prev_time
)
self
.
prev_time
=
time
.
time
()
sys
.
stdout
.
write
(
'
\r
Epoch
%03
d/
%03
d [
%04
d/
%04
d] -- '
%
(
self
.
epoch
,
self
.
n_epochs
,
self
.
batch
,
self
.
batches_epoch
))
for
i
,
loss_name
in
enumerate
(
losses
.
keys
()):
if
loss_name
not
in
self
.
losses
:
self
.
losses
[
loss_name
]
=
losses
[
loss_name
]
.
data
else
:
self
.
losses
[
loss_name
]
+=
losses
[
loss_name
]
.
data
if
(
i
+
1
)
==
len
(
losses
.
keys
()):
sys
.
stdout
.
write
(
'
%
s:
%.4
f -- '
%
(
loss_name
,
self
.
losses
[
loss_name
]
/
self
.
batch
))
else
:
sys
.
stdout
.
write
(
'
%
s:
%.4
f | '
%
(
loss_name
,
self
.
losses
[
loss_name
]
/
self
.
batch
))
batches_done
=
self
.
batches_epoch
*
(
self
.
epoch
-
1
)
+
self
.
batch
batches_left
=
self
.
batches_epoch
*
(
self
.
n_epochs
-
self
.
epoch
)
+
self
.
batches_epoch
-
self
.
batch
sys
.
stdout
.
write
(
'ETA:
%
s'
%
(
datetime
.
timedelta
(
seconds
=
batches_left
*
self
.
mean_period
/
batches_done
)))
# Draw images
for
image_name
,
tensor
in
images
.
items
():
if
image_name
not
in
self
.
image_windows
:
self
.
image_windows
[
image_name
]
=
self
.
viz
.
image
(
tensor2image
(
tensor
.
data
),
opts
=
{
'title'
:
image_name
})
else
:
self
.
viz
.
image
(
tensor2image
(
tensor
.
data
),
win
=
self
.
image_windows
[
image_name
],
opts
=
{
'title'
:
image_name
})
# End of epoch
if
(
self
.
batch
%
self
.
batches_epoch
)
==
0
:
# Plot losses
for
loss_name
,
loss
in
self
.
losses
.
items
():
loss
=
loss
.
cpu
()
if
loss_name
not
in
self
.
loss_windows
:
self
.
loss_windows
[
loss_name
]
=
self
.
viz
.
line
(
X
=
np
.
array
([
self
.
epoch
]),
Y
=
np
.
array
([
loss
/
self
.
batch
]),
opts
=
{
'xlabel'
:
'epochs'
,
'ylabel'
:
loss_name
,
'title'
:
loss_name
})
else
:
self
.
viz
.
line
(
X
=
np
.
array
([
self
.
epoch
]),
Y
=
np
.
array
([
loss
/
self
.
batch
]),
win
=
self
.
loss_windows
[
loss_name
],
update
=
'append'
)
# Reset losses for next epoch
self
.
losses
[
loss_name
]
=
0.0
self
.
epoch
+=
1
self
.
batch
=
1
sys
.
stdout
.
write
(
'
\n
'
)
else
:
self
.
batch
+=
1
class
ReplayBuffer
():
def
__init__
(
self
,
max_size
=
50
):
assert
(
max_size
>
0
),
'Empty buffer or trying to create a black hole. Be careful.'
self
.
max_size
=
max_size
self
.
data
=
[]
def
push_and_pop
(
self
,
data
):
to_return
=
[]
for
element
in
data
.
data
:
element
=
torch
.
unsqueeze
(
element
,
0
)
if
len
(
self
.
data
)
<
self
.
max_size
:
self
.
data
.
append
(
element
)
to_return
.
append
(
element
)
else
:
if
random
.
uniform
(
0
,
1
)
>
0.5
:
i
=
random
.
randint
(
0
,
self
.
max_size
-
1
)
to_return
.
append
(
self
.
data
[
i
]
.
clone
())
self
.
data
[
i
]
=
element
else
:
to_return
.
append
(
element
)
return
Variable
(
torch
.
cat
(
to_return
))
class
LambdaLR
():
def
__init__
(
self
,
n_epochs
,
offset
,
decay_start_epoch
):
assert
((
n_epochs
-
decay_start_epoch
)
>
0
),
"Decay must start before the training session ends!"
self
.
n_epochs
=
n_epochs
self
.
offset
=
offset
self
.
decay_start_epoch
=
decay_start_epoch
def
step
(
self
,
epoch
):
return
1.0
-
max
(
0
,
epoch
+
self
.
offset
-
self
.
decay_start_epoch
)
/
(
self
.
n_epochs
-
self
.
decay_start_epoch
)
def
weights_init_normal
(
m
):
classname
=
m
.
__class__
.
__name__
if
classname
.
find
(
'Conv'
)
!=
-
1
:
torch
.
nn
.
init
.
normal
(
m
.
weight
.
data
,
0.0
,
0.02
)
elif
classname
.
find
(
'BatchNorm2d'
)
!=
-
1
:
torch
.
nn
.
init
.
normal
(
m
.
weight
.
data
,
1.0
,
0.02
)
torch
.
nn
.
init
.
constant
(
m
.
bias
.
data
,
0.0
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment