Commit 91e3d0cf by 前钰

Upload New File

parent 68ddd0a6
# 此训练脚本更改了第10行,47行
# 此训练脚本更改了第10行,47行
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from tqdm import tqdm
from model import SimpleCNN,BetterCNN
# ------------------- argparse 参数解析部分 -------------------
parser = argparse.ArgumentParser(description="猫狗分类训练脚本")
parser.add_argument("--model_name", type=str, default="alexnet", help="模型名称,例如 SimpleCNN")
parser.add_argument("--lr", type=float, default=0.001, help="学习率")
parser.add_argument("--batch_size", type=int, default=32, help="批大小")
parser.add_argument("--epochs", type=int, default=10, help="训练轮次")
parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "sgd"], help="优化器类型")
parser.add_argument("--train_dir", type=str, default=r"data/train", help="训练集路径")
parser.add_argument("--val_dir", type=str, default=r"data/val", help="验证集路径")
parser.add_argument("--save_path", type=str, default="cat_dog_cnn_new.pth", help="模型保存路径")
args = parser.parse_args()
# ------------------- 数据预处理 -------------------
train_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
val_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
train_dataset = ImageFolder(args.train_dir, transform=train_transform)
val_dataset = ImageFolder(args.val_dir, transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
# ------------------- 模型定义 -------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BetterCNN().to(device)
# ------------------- 损失函数 -------------------
criterion = nn.CrossEntropyLoss()
# ------------------- 优化器选择 -------------------
if args.optimizer == "adam":
optimizer = optim.Adam(model.parameters(), lr=args.lr)
elif args.optimizer == "sgd":
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
# ------------------- 训练函数 -------------------
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=10):
for epoch in range(epochs):
model.train()
train_loss, correct, total = 0, 0, 0
for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item() * images.size(0)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
acc = 100 * correct / total
print(f"Train Loss: {train_loss/total:.4f}, Accuracy: {acc:.2f}%")
model.eval()
val_loss, correct, total = 0, 0, 0
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item() * images.size(0)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
val_acc = 100 * correct / total
print(f"Val Loss: {val_loss/total:.4f}, Accuracy: {val_acc:.2f}%\n")
# ------------------- 启动训练 -------------------
train_model(model, train_loader, val_loader, criterion, optimizer, epochs=args.epochs)
# ------------------- 保存模型 -------------------
torch.save(model.state_dict(), args.save_path)
print(f"模型已保存至 {args.save_path}")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment