Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
人
人工智能系统实战第三期
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ding
人工智能系统实战第三期
Commits
ec0185af
Commit
ec0185af
authored
Dec 09, 2023
by
前钰
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Upload New File
parent
2df4106f
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
0 deletions
+46
-0
main.py
...学习项目实战/基于transformer的花朵识别/ViTFlowerClassification/main.py
+46
-0
No files found.
人工智能系统实战第三期/实战代码/深度学习项目实战/基于transformer的花朵识别/ViTFlowerClassification/main.py
0 → 100644
View file @
ec0185af
import
os
import
os
import
time
from
train
import
train
from
predict
import
predict
if
__name__
==
'__main__'
:
args
=
{
# 部分参数需要斟酌调整大小!
'num_classes'
:
5
,
# 手动设置分成几类。该项目为花朵分类,所以设置为5类
'label_name'
:
[
"daisy"
,
"dandelion"
,
"roses"
,
"sunflowers"
,
"tulips"
],
# 手动设置标签名称
'epochs'
:
100
,
# 设置训练的轮数
'batch_size'
:
128
,
# 设置每批读入的图片数量
'lr'
:
1e-3
,
# 设置学习率
'lrf'
:
1e-2
,
# 设置学习率优化策略的参数
'train_dir'
:
'./flower/train'
,
# 设置训练集路径
'val_dir'
:
'./flower/val'
,
# 设置测试集路径
'summary_dir'
:
'./summary'
,
# 设置训练结果与日志的存储路径
'use_weights'
:
True
,
#
'gpu_list'
:
'0'
,
# 对于多GPU训练,设置GPU列表
'model_type'
:
'vit_base_patch16_224'
,
# 选择一个模型进行训练
}
# 此处代码根据需要,二选一运行即可
# 如果使用预训练模型,设置使用预训练的模型的名称和位置
args
[
'weights_name'
]
=
str
(
'/pretrained_model/'
+
args
[
'model_type'
]
+
'_'
+
time
.
strftime
(
"
%
Y-
%
m-
%
d"
,
time
.
localtime
())
+
'.pth'
)
# 如果从头训练,则置为空字符串
args
[
'weights_name'
]
=
''
# 给系统设置多gpu并行训练的gpu列表
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
args
[
'gpu_list'
]
# print('------------------训练阶段------------------')
# # 进行训练
# train(args=args)
print
(
'------------------测试阶段------------------'
)
args
[
'predicted_image'
]
=
"./flower/val/daisy/173350276_02817aa8d5.jpg"
args
[
'saved_pth'
]
=
"{}/weights/epoch=74_val_acc=0.5220.pth"
.
format
(
args
[
'summary_dir'
])
#选择一张图片进行测试
predict
(
args
=
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment