Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
点
点头人工智能课程-v6.0-通识
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
靓靓
点头人工智能课程-v6.0-通识
Commits
91e3d0cf
Commit
91e3d0cf
authored
Jul 01, 2025
by
前钰
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Upload New File
parent
68ddd0a6
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
97 additions
and
0 deletions
+97
-0
train_new.py
5-深度学习/5.6-作业/train_new.py
+97
-0
No files found.
5-深度学习/5.6-作业/train_new.py
0 → 100644
View file @
91e3d0cf
# 此训练脚本更改了第10行,47行
# 此训练脚本更改了第10行,47行
import
argparse
import
torch
import
torch.nn
as
nn
import
torch.optim
as
optim
from
torchvision
import
models
,
transforms
from
torchvision.datasets
import
ImageFolder
from
torch.utils.data
import
DataLoader
from
tqdm
import
tqdm
from
model
import
SimpleCNN
,
BetterCNN
# ------------------- argparse 参数解析部分 -------------------
parser
=
argparse
.
ArgumentParser
(
description
=
"猫狗分类训练脚本"
)
parser
.
add_argument
(
"--model_name"
,
type
=
str
,
default
=
"alexnet"
,
help
=
"模型名称,例如 SimpleCNN"
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
0.001
,
help
=
"学习率"
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
32
,
help
=
"批大小"
)
parser
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
10
,
help
=
"训练轮次"
)
parser
.
add_argument
(
"--optimizer"
,
type
=
str
,
default
=
"adam"
,
choices
=
[
"adam"
,
"sgd"
],
help
=
"优化器类型"
)
parser
.
add_argument
(
"--train_dir"
,
type
=
str
,
default
=
r"data/train"
,
help
=
"训练集路径"
)
parser
.
add_argument
(
"--val_dir"
,
type
=
str
,
default
=
r"data/val"
,
help
=
"验证集路径"
)
parser
.
add_argument
(
"--save_path"
,
type
=
str
,
default
=
"cat_dog_cnn_new.pth"
,
help
=
"模型保存路径"
)
args
=
parser
.
parse_args
()
# ------------------- 数据预处理 -------------------
train_transform
=
transforms
.
Compose
([
transforms
.
Resize
((
256
,
256
)),
transforms
.
RandomCrop
(
224
),
transforms
.
ToTensor
(),
transforms
.
Normalize
([
0.5
,
0.5
,
0.5
],
[
0.5
,
0.5
,
0.5
])
])
val_transform
=
transforms
.
Compose
([
transforms
.
Resize
((
256
,
256
)),
transforms
.
CenterCrop
(
224
),
transforms
.
ToTensor
(),
transforms
.
Normalize
([
0.5
,
0.5
,
0.5
],
[
0.5
,
0.5
,
0.5
])
])
train_dataset
=
ImageFolder
(
args
.
train_dir
,
transform
=
train_transform
)
val_dataset
=
ImageFolder
(
args
.
val_dir
,
transform
=
val_transform
)
train_loader
=
DataLoader
(
train_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
,
num_workers
=
0
)
val_loader
=
DataLoader
(
val_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
num_workers
=
0
)
# ------------------- 模型定义 -------------------
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
model
=
BetterCNN
()
.
to
(
device
)
# ------------------- 损失函数 -------------------
criterion
=
nn
.
CrossEntropyLoss
()
# ------------------- 优化器选择 -------------------
if
args
.
optimizer
==
"adam"
:
optimizer
=
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
elif
args
.
optimizer
==
"sgd"
:
optimizer
=
optim
.
SGD
(
model
.
parameters
(),
lr
=
args
.
lr
,
momentum
=
0.9
)
# ------------------- 训练函数 -------------------
def
train_model
(
model
,
train_loader
,
val_loader
,
criterion
,
optimizer
,
epochs
=
10
):
for
epoch
in
range
(
epochs
):
model
.
train
()
train_loss
,
correct
,
total
=
0
,
0
,
0
for
images
,
labels
in
tqdm
(
train_loader
,
desc
=
f
"Epoch {epoch+1}/{epochs}"
):
images
,
labels
=
images
.
to
(
device
),
labels
.
to
(
device
)
outputs
=
model
(
images
)
loss
=
criterion
(
outputs
,
labels
)
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
train_loss
+=
loss
.
item
()
*
images
.
size
(
0
)
_
,
predicted
=
torch
.
max
(
outputs
,
1
)
total
+=
labels
.
size
(
0
)
correct
+=
(
predicted
==
labels
)
.
sum
()
.
item
()
acc
=
100
*
correct
/
total
print
(
f
"Train Loss: {train_loss/total:.4f}, Accuracy: {acc:.2f}
%
"
)
model
.
eval
()
val_loss
,
correct
,
total
=
0
,
0
,
0
with
torch
.
no_grad
():
for
images
,
labels
in
val_loader
:
images
,
labels
=
images
.
to
(
device
),
labels
.
to
(
device
)
outputs
=
model
(
images
)
loss
=
criterion
(
outputs
,
labels
)
val_loss
+=
loss
.
item
()
*
images
.
size
(
0
)
_
,
predicted
=
torch
.
max
(
outputs
,
1
)
total
+=
labels
.
size
(
0
)
correct
+=
(
predicted
==
labels
)
.
sum
()
.
item
()
val_acc
=
100
*
correct
/
total
print
(
f
"Val Loss: {val_loss/total:.4f}, Accuracy: {val_acc:.2f}
%
\n
"
)
# ------------------- 启动训练 -------------------
train_model
(
model
,
train_loader
,
val_loader
,
criterion
,
optimizer
,
epochs
=
args
.
epochs
)
# ------------------- 保存模型 -------------------
torch
.
save
(
model
.
state_dict
(),
args
.
save_path
)
print
(
f
"模型已保存至 {args.save_path}"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment