Commit 0ec407a0 by Leo

upload code

parent dce1ff65
# 代码资料
train.zip
链接: https://pan.baidu.com/s/1tpzBvQsleRTaCJuJqopQzw?pwd=gdqa
提取码: gdqa
# 数据部分作业:
参照课上讲的代码,对猫狗分类数据集进行训练集和验证集的划分(9:1),并给出此数据集的dataset和dataloader代码。最终输出形式:获取一个批次的数据...
批次形状: 图像 torch.Size([32, 3, 224, 224]), 标签 torch.Size([32])。
\ No newline at end of file
import os
import os
import random
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import os
# Allow multiple OpenMP runtimes to coexist in one process
# (works around the libiomp duplicate-load crash common with
# conda MKL + PyTorch setups).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Seed every RNG used in this script (python, numpy, torch)
# so shuffles and augmentations are reproducible.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
# Prefer GPU when available; pin_memory only benefits CUDA host->device copies.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_pin_memory = device.type == 'cuda'
class CatDogDataset(Dataset):
    """Custom dataset for binary cat/dog classification.

    Labels are derived from the filename prefix, matching the Kaggle
    cats-vs-dogs layout (``cat.<n>.jpg`` -> 0, ``dog.<n>.jpg`` -> 1).
    """

    def __init__(self, data_dir, file_list, transform=None):
        """
        Initialize the dataset.

        Args:
            data_dir: directory containing the image files.
            file_list: image filenames, relative to ``data_dir``.
            transform: optional callable applied to each PIL image.
        """
        self.data_dir = data_dir
        self.file_list = file_list
        self.transform = transform
        # Class-name -> integer-label mapping; the single source of
        # truth for labels (used by __getitem__ below).
        self.class_to_idx = {'cat': 0, 'dog': 1}

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load one sample and return ``(image, label)``."""
        filename = self.file_list[idx]
        img_path = os.path.join(self.data_dir, filename)
        # Force RGB so grayscale/palette images still yield 3 channels.
        image = Image.open(img_path).convert('RGB')
        # Derive the label from the filename prefix via class_to_idx
        # instead of the previous hard-coded ternary, which ignored the
        # mapping declared in __init__. Unknown prefixes still fall
        # back to the 'dog' label, preserving the original behavior.
        prefix = filename.split('.', 1)[0]
        label = self.class_to_idx.get(prefix, self.class_to_idx['dog'])
        if self.transform:
            image = self.transform(image)
        return image, label
def get_file_list(data_dir):
    """Return image filenames found directly in *data_dir*.

    Extensions .jpg/.jpeg/.png are matched case-insensitively, so files
    named e.g. ``CAT.1.JPG`` are no longer silently skipped.
    """
    image_exts = ('.jpg', '.jpeg', '.png')
    return [f for f in os.listdir(data_dir) if f.lower().endswith(image_exts)]
def split_dataset(file_list, train_ratio=0.9):
    """Randomly split *file_list* into (train, val) lists.

    Shuffles a copy rather than the caller's list — the original code
    mutated *file_list* in place as a side effect.

    Args:
        file_list: filenames to split.
        train_ratio: fraction assigned to the training set.

    Returns:
        Tuple ``(train_files, val_files)``.
    """
    shuffled = list(file_list)
    random.shuffle(shuffled)
    split_idx = int(len(shuffled) * train_ratio)
    return shuffled[:split_idx], shuffled[split_idx:]
def visualize_batch(images, labels, save_path='batch_visualization.png'):
    """Save a grid visualization of one batch of images with labels.

    Args:
        images: float tensor of shape (B, 3, H, W), normalized with the
            ImageNet mean/std used in the transforms below.
        labels: int tensor of shape (B,); 0 = cat, 1 = dog.
        save_path: destination PNG path.
    """
    # Undo Normalize and move channels last for imshow.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    imgs = images.numpy().transpose((0, 2, 3, 1))
    imgs = np.clip(std * imgs + mean, 0, 1)

    batch_size = imgs.shape[0]
    # Size the grid to the batch instead of the previous hard-coded
    # 4x8 layout, which raised IndexError for batches larger than 32
    # and left visible empty axes for smaller ones.
    ncols = min(8, batch_size)
    nrows = (batch_size + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 2.5 * nrows))
    axes = np.atleast_1d(axes).flatten()
    for i, ax in enumerate(axes):
        if i < batch_size:
            ax.imshow(imgs[i])
            ax.set_title('Cat' if labels[i] == 0 else 'Dog', fontsize=12)
        # Also hides the frame of any unused trailing cells.
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"批量可视化结果已保存至: {save_path}")
def main():
    """Split the cats-vs-dogs dataset 9:1, build datasets/loaders,
    fetch one batch, and save a visualization plus a split report."""
    DATA_DIR = 'train'      # directory holding cat.*.jpg / dog.*.jpg files
    BATCH_SIZE = 32
    IMAGE_SIZE = 224        # input resolution used by the transforms below

    Path('output').mkdir(exist_ok=True)

    # 1. Collect image filenames.
    print("获取文件列表...")
    all_files = get_file_list(DATA_DIR)
    print(f"找到 {len(all_files)} 张图片")

    # 2. 9:1 train/validation split.
    print("划分数据集 (9:1)...")
    train_files, val_files = split_dataset(all_files, train_ratio=0.9)
    print(f"训练集大小: {len(train_files)} 张图片")
    print(f"验证集大小: {len(val_files)} 张图片")

    # 3. Transforms: random augmentation for training, deterministic
    #    resize/center-crop for validation; both use ImageNet stats.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # 4. Datasets.
    print("创建数据集...")
    train_dataset = CatDogDataset(DATA_DIR, train_files, transform=train_transform)
    val_dataset = CatDogDataset(DATA_DIR, val_files, transform=val_transform)

    # 5. Data loaders. val_loader is unused below but kept as part of
    #    the assignment's required output.
    print("创建数据加载器...")
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=use_pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=use_pin_memory,
    )

    # 6. Fetch one batch to verify shapes.
    print("获取一个批次的数据...")
    images, labels = next(iter(train_loader))
    print(f"批次形状: 图像 {images.shape}, 标签 {labels.shape}")

    # 7. Visualize the batch.
    print("可视化批次数据...")
    visualize_batch(images, labels, save_path='output/batch_visualization.png')

    # 8. Write the split report. encoding='utf-8' is required: the file
    #    contains Chinese text, and the platform-default codec (e.g.
    #    cp936/cp1252 on Windows) could fail or mis-encode it.
    with open('output/dataset_split.txt', 'w', encoding='utf-8') as f:
        f.write(f"总样本数: {len(all_files)}\n")
        f.write(f"训练集大小: {len(train_files)} ({len(train_files)/len(all_files):.1%})\n")
        f.write(f"验证集大小: {len(val_files)} ({len(val_files)/len(all_files):.1%})\n")
        f.write("\n前10个训练样本:\n")
        for file in train_files[:10]:
            f.write(f"{file}\n")
        f.write("\n前10个验证样本:\n")
        for file in val_files[:10]:
            f.write(f"{file}\n")
    print("数据集划分信息已保存至: output/dataset_split.txt")
    print("处理完成!")


if __name__ == '__main__':
    main()
\ No newline at end of file
# 代码资料
COVID_Classification.zip
链接: https://pan.baidu.com/s/1XCtGt-NH2b8-b5otUa8LPw?pwd=yrv9
提取码: yrv9
-- "a/3-\347\256\227\346\263\225\350\256\262\350\247\243/3.2-\345\233\276\345\203\217\345\210\206\347\261\273\357\274\210\344\270\213\357\274\211/.gitkeep"
++ /dev/null
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment