liyinkai / 人工智能系统实战第三期 / Commits

Commit 6c412873, authored Oct 30, 2023 by 前钰 (parent c904d45d)
Showing 1 changed file with 124 additions and 0 deletions (+124, -0)

人工智能系统实战第三期/实战代码/深度学习项目实战/神经网络实现鸢尾花分类/fully_connected_network.py  (new file, 0 → 100644)
import os
import sys
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_loader import iris_dataload
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=100, help='the number of classes')
parser.add_argument('--epochs', type=int, default=20, help='the number of training epochs')
parser.add_argument('--batch_size', type=int, default=16, help='batch size for training')
parser.add_argument('--lr', type=float, default=0.005, help='start learning rate')
parser.add_argument('--data_path', type=str, default="/mnt/d/Codes/GNN/NN/Iris_data.txt")
parser.add_argument('--device', default='cuda', help='device id (i.e. 0 or 0,1 or cpu)')
opt = parser.parse_args()
# Initialize the neural network: three fully connected (Linear) layers
class Neuralnetwork(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Neuralnetwork, self).__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
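
# Note (sketch, not part of this commit): the three Linear layers above have no
# activation functions between them, so the forward pass composes to a single
# affine map. If a nonlinear classifier is wanted, ReLU activations could be
# inserted between the layers; a minimal variant (the class name is illustrative):
class NeuralnetworkWithReLU(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1), nn.ReLU(),
            nn.Linear(n_hidden_1, n_hidden_2), nn.ReLU(),
            nn.Linear(n_hidden_2, out_dim),
        )

    def forward(self, x):
        return self.net(x)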
# Select the training device for the model
device = torch.device(opt.device if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
# Split the dataset and build the data loaders
custom_dataset = iris_dataload("./Iris_data.txt")  # note: the path here is hardcoded; the --data_path argument above is not used
train_size = int(len(custom_dataset) * 0.7)
validate_size = int(len(custom_dataset) * 0.2)
test_size = len(custom_dataset) - validate_size - train_size
train_dataset, validate_dataset, test_dataset = torch.utils.data.random_split(
    custom_dataset, [train_size, validate_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=False)
validate_loader = DataLoader(validate_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print("Training set data size:", len(train_dataset),
      ",Validating set data size:", len(validate_dataset),
      ",Testing set data size:", len(test_dataset))
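
# Note (sketch, not part of this commit): random_split above draws a different
# split on every run. For a reproducible split, a seeded generator could be
# passed instead (the seed value 0 is arbitrary), e.g.:
#   train_dataset, validate_dataset, test_dataset = torch.utils.data.random_split(
#       custom_dataset, [train_size, validate_size, test_size],
#       generator=torch.Generator().manual_seed(0))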
# Inference routine: returns the accuracy; used in the validation and test stages
def infer(model, dataset, device):
    model.eval()
    acc_num = 0.0
    with torch.no_grad():
        for data in dataset:
            datas, labels = data
            outputs = model(datas.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc_num += torch.eq(predict_y, labels.to(device)).sum().item()
    accuracy = acc_num / len(dataset)
    return accuracy
# Training, validation and test procedure
def main(args):
    print(args)
    model = Neuralnetwork(4, 12, 6, 3).to(device)             # instantiate the model
    loss_function = nn.CrossEntropyLoss()                     # define the loss function
    pg = [p for p in model.parameters() if p.requires_grad]   # collect the trainable model parameters
    optimizer = optim.Adam(pg, lr=args.lr)                    # define the optimizer

    # Directory where the model weights are saved
    save_path = os.path.join(os.getcwd(), 'results/weights')
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Start the training loop
    for epoch in range(opt.epochs):
        ############################## train ##############################
        model.train()
        acc_num = torch.zeros(1).to(device)  # initialize: counts correct predictions during training
        sample_num = 0                       # initialize: records how many samples have been processed in the current epoch

        # tqdm is a progress bar that prints the current training progress in the terminal
        train_bar = tqdm(train_loader, file=sys.stdout, ncols=100)
        for datas in train_bar:
            data, label = datas
            label = label.squeeze(-1)
            sample_num += data.shape[0]
            optimizer.zero_grad()
            outputs = model(data.to(device))  # output shape: [batch_size, num_classes]
            pred_class = torch.max(outputs, dim=1)[1]  # torch.max returns a tuple: the max values and the indices of those max values
            acc_num += torch.eq(pred_class, label.to(device)).sum()
            loss = loss_function(outputs, label.to(device))  # compute the loss
            loss.backward()   # backward pass (automatic differentiation)
            optimizer.step()  # gradient descent step

            # print statistics
            train_acc = acc_num.item() / sample_num
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, opt.epochs, loss)

        ############################## validate ##############################
        val_accurate = infer(model=model, dataset=validate_loader, device=device)
        print('[epoch %d] train_loss: %.3f  train_acc: %.3f  val_accuracy: %.3f' %
              (epoch + 1, loss, train_acc, val_accurate))
        torch.save(model.state_dict(), os.path.join(save_path, "NN.pth"))

        # Reset these metrics after each epoch so they are recomputed
        train_acc = 0.0
        val_accurate = 0.0

    print('Finished Training')

    ############################## test ##############################
    test_accurate = infer(model=model, dataset=test_loader, device=device)
    print(' test_accuracy: %.3f' % (test_accurate))


if __name__ == '__main__':
    main(opt)
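
A possible way to reuse the weights this script saves, for a single prediction (a sketch, not part of the commit): it assumes the NN.pth checkpoint written by main(), the same 4-12-6-3 architecture, and that iris_dataload feeds raw measurements in the order sepal length, sepal width, petal length, petal width. If the data loader normalizes its inputs, the same preprocessing would be needed here, and the Neuralnetwork class would have to be copied or imported separately (importing this script directly would also run its top-level setup).

import torch

# Rebuild the architecture and load the checkpoint written by the training run.
model = Neuralnetwork(4, 12, 6, 3)
model.load_state_dict(torch.load("results/weights/NN.pth", map_location="cpu"))
model.eval()

# One flower (illustrative values, in cm): sepal length, sepal width, petal length, petal width
sample = torch.tensor([[5.1, 3.5, 1.4, 0.2]])
with torch.no_grad():
    pred = torch.max(model(sample), dim=1)[1].item()
print("predicted class index:", pred)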