Commit a2bfbc4c by 靓靓

upload files

parent 5518120b
++ "b/2-\345\272\224\347\224\250/2.06-AI\351\241\271\347\233\256\346\265\201\347\250\213/.gitkeep"
++ "b/2-\345\272\224\347\224\250/2.07-\345\233\276\345\203\217\350\257\206\345\210\253/.gitkeep"
++ "b/2-\345\272\224\347\224\250/2.07-\345\233\276\345\203\217\350\257\206\345\210\253/2.07.10-\345\233\276\345\203\217\350\257\206\345\210\253\346\250\241\345\236\213\345\272\223\347\232\204\345\260\201\350\243\205/.gitkeep"
import torch
import torch.nn as nn
class AlexNet(nn.Module):
def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2), # input[3, 224, 224] output[96, 55, 55]
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), # output[96, 27, 27]
nn.Conv2d(96, 256, kernel_size=5, padding=2), # output[256, 27, 27]
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), # output[256, 13, 13]
nn.Conv2d(256, 384, kernel_size=3, padding=1), # output[384, 13, 13]
nn.ReLU(inplace=True),
nn.Conv2d(384, 384, kernel_size=3, padding=1), # output[384, 13, 13]
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1), # output[256, 13, 13]
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), # output[256, 6, 6]
)
self.classifier = nn.Sequential(
nn.Dropout(p=0.5),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(p=0.5),
nn.Linear(4096, 1000),
nn.ReLU(inplace=True),
nn.Linear(1000, num_classes),
)
def forward(self, x):
x = self.features(x) # 256 feature maps of size 6x6
x = torch.flatten(x, start_dim=1) # flattened to a vector of length 256 * 6 * 6
x = self.classifier(x)
return x
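A quick sanity check of the network above (illustrative; `num_classes=5` matches the flower classifier used in the app below):

```python
if __name__ == "__main__":
    net = AlexNet(num_classes=5)
    dummy = torch.randn(1, 3, 224, 224)  # a batch of one 3x224x224 image
    print(net(dummy).shape)              # expected: torch.Size([1, 5])
```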
import sys
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QVBoxLayout, QWidget, QPushButton, QFileDialog, QLineEdit
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QPixmap
import os
import json
import torch.nn as nn
from AlexNet import AlexNet
import torch
from PIL import Image
from torchvision import transforms
class MainWindow(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("图片上传和展示")
self.setGeometry(100, 100, 400, 300)
# Create the label, button and text box
self.image_label = QLabel(self)
self.upload_button = QPushButton("上传图片", self)
self.label_textbox = QLineEdit(self)
# Set positions and sizes of the label, button and text box
self.image_label.setGeometry(10, 10, 380, 200)
self.upload_button.setGeometry(150, 220, 100, 30)
self.label_textbox.setGeometry(10, 250, 380, 30)
# Connect the button's click event
self.upload_button.clicked.connect(self.upload_image)
def upload_image(self):
# Open a file dialog to choose an image file
file_dialog = QFileDialog()
image_path, _ = file_dialog.getOpenFileName(self, "选择图片", "", "Images (*.png *.xpm *.jpg *.bmp)")
if not image_path: # user cancelled the dialog
return
# Load and display the selected image
pixmap = QPixmap(image_path)
self.image_label.setPixmap(pixmap.scaled(self.image_label.size(), Qt.KeepAspectRatio))
# Run the model to get the image label
label = self.get_image_label(image_path)
self.label_textbox.setText(label)
def get_image_label(self, image_path):
# Load the model and class labels
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Load the image (convert to RGB so the 3-channel normalization always applies)
img = Image.open(image_path).convert('RGB')
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)
img = img.to(device)
# Class-index to label mapping
class_indict ={
"0": "daisy",
"1": "dandelion",
"2": "roses",
"3": "sunflowers",
"4": "tulips"
}
# Create the model and load the trained weights
model = AlexNet(num_classes=5).to(device)
weights_path = "E:/PythonProject/AlexNetAPP/checkpoints/alex/alex_flower.pth"
assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
model.load_state_dict(torch.load(weights_path, map_location=device))
model.eval()
# Run inference
with torch.no_grad():
output = torch.squeeze(model(img)).cpu()
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
label = class_indict[str(predict_cla)]
return label
if __name__ == '__main__':
app = QApplication(sys.argv)
window = MainWindow()
window.show()
sys.exit(app.exec_())
Model weights download link<br>
Link: https://pan.baidu.com/s/1qjjfXn041S8fWVuwuRsX8w?pwd=DTAI <br>
Extraction code: DTAI
++ "b/2-\345\272\224\347\224\250/2.08-\350\257\255\344\271\211\345\210\206\345\211\262/.gitkeep"
++ "b/2-\345\272\224\347\224\250/2.08-\350\257\255\344\271\211\345\210\206\345\211\262/2.08.3-U-net\345\214\273\345\255\246\345\233\276\345\203\217\345\210\206\345\211\262/.gitkeep"
# U-Net: Semantic segmentation with PyTorch
<a href="#"><img src="https://img.shields.io/github/actions/workflow/status/milesial/PyTorch-UNet/main.yml?logo=github&style=for-the-badge" /></a>
<a href="https://hub.docker.com/r/milesial/unet"><img src="https://img.shields.io/badge/docker%20image-available-blue?logo=Docker&style=for-the-badge" /></a>
<a href="https://pytorch.org/"><img src="https://img.shields.io/badge/PyTorch-v1.13+-red.svg?logo=PyTorch&style=for-the-badge" /></a>
<a href="#"><img src="https://img.shields.io/badge/python-v3.6+-blue.svg?logo=python&style=for-the-badge" /></a>
![input and output for a random image in the test dataset](https://i.imgur.com/GD8FcB7.png)
Customized implementation of the [U-Net](https://arxiv.org/abs/1505.04597) in PyTorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge) from high definition images.
- [Quick start](#quick-start)
- [Without Docker](#without-docker)
- [With Docker](#with-docker)
- [Description](#description)
- [Usage](#usage)
- [Docker](#docker)
- [Training](#training)
- [Prediction](#prediction)
- [Weights & Biases](#weights--biases)
- [Pretrained model](#pretrained-model)
- [Data](#data)
## Quick start
### Without Docker
1. [Install CUDA](https://developer.nvidia.com/cuda-downloads)
2. [Install PyTorch 1.13 or later](https://pytorch.org/get-started/locally/)
3. Install dependencies
```bash
pip install -r requirements.txt
```
4. Download the data and run training:
```bash
bash scripts/download_data.sh
python train.py --amp
```
### With Docker
1. [Install Docker 19.03 or later:](https://docs.docker.com/get-docker/)
```bash
curl https://get.docker.com | sh && sudo systemctl --now enable docker
```
2. [Install the NVIDIA container toolkit:](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)
```bash
distribution=$(. /etc/os-release;echo $ID$VERSION_ID) \
&& curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - \
&& curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
sudo apt-get update
sudo apt-get install -y nvidia-docker2
sudo systemctl restart docker
```
3. [Download and run the image:](https://hub.docker.com/repository/docker/milesial/unet)
```bash
sudo docker run --rm --shm-size=8g --ulimit memlock=-1 --gpus all -it milesial/unet
```
4. Download the data and run training:
```bash
bash scripts/download_data.sh
python train.py --amp
```
## Description
This model was trained from scratch with 5k images and scored a [Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 on over 100k test images.
It can be easily used for multiclass segmentation, portrait segmentation, medical segmentation, ...
## Usage
**Note: Use Python 3.6 or newer**
### Docker
A docker image containing the code and the dependencies is available on [DockerHub](https://hub.docker.com/repository/docker/milesial/unet).
You can download and jump in the container with ([docker >=19.03](https://docs.docker.com/get-docker/)):
```console
docker run -it --rm --shm-size=8g --ulimit memlock=-1 --gpus all milesial/unet
```
### Training
```console
> python train.py -h
usage: train.py [-h] [--epochs E] [--batch-size B] [--learning-rate LR]
[--load LOAD] [--scale SCALE] [--validation VAL] [--amp]
Train the UNet on images and target masks
optional arguments:
-h, --help show this help message and exit
--epochs E, -e E Number of epochs
--batch-size B, -b B Batch size
--learning-rate LR, -l LR
Learning rate
--load LOAD, -f LOAD Load model from a .pth file
--scale SCALE, -s SCALE
Downscaling factor of the images
--validation VAL, -v VAL
Percent of the data that is used as validation (0-100)
--amp Use mixed precision
```
By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1.
Automatic mixed precision is also available with the `--amp` flag. [Mixed precision](https://arxiv.org/abs/1710.03740) allows the model to use less memory and to be faster on recent GPUs by using FP16 arithmetic. Enabling AMP is recommended.
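For example, a full-resolution training run with AMP might look like this (illustrative values; the flags are those listed in the help above):

```bash
python train.py --epochs 5 --batch-size 1 --scale 1 --amp
```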
### Prediction
After training your model and saving it to `MODEL.pth`, you can easily test the output masks on your images via the CLI.
To predict a single image and save it:
`python predict.py -i image.jpg -o output.jpg`
To predict multiple images and show them without saving them:
`python predict.py -i image1.jpg image2.jpg --viz --no-save`
```console
> python predict.py -h
usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
[--output INPUT [INPUT ...]] [--viz] [--no-save]
[--mask-threshold MASK_THRESHOLD] [--scale SCALE]
Predict masks from input images
optional arguments:
-h, --help show this help message and exit
--model FILE, -m FILE
Specify the file in which the model is stored
--input INPUT [INPUT ...], -i INPUT [INPUT ...]
Filenames of input images
--output INPUT [INPUT ...], -o INPUT [INPUT ...]
Filenames of output images
--viz, -v Visualize the images as they are processed
--no-save, -n Do not save the output masks
--mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
Minimum probability value to consider a mask pixel white
--scale SCALE, -s SCALE
Scale factor for the input images
```
You can specify which model file to use with `--model MODEL.pth`.
## Weights & Biases
The training progress can be visualized in real-time using [Weights & Biases](https://wandb.ai/). Loss curves, validation curves, weights and gradient histograms, as well as predicted masks are logged to the platform.
When launching a training, a link will be printed in the console. Click on it to go to your dashboard. If you have an existing W&B account, you can link it
by setting the `WANDB_API_KEY` environment variable. If not, it will create an anonymous run which is automatically deleted after 7 days.
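For example (illustrative; replace the placeholder with your own key):

```bash
WANDB_API_KEY=<your-api-key> python train.py --amp
```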
## Pretrained model
A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v3.0) is available for the Carvana dataset. It can also be loaded from torch.hub:
```python
net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5)
```
Available scales are 0.5 and 1.0.
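A minimal inference sketch with the hub model (illustrative; `img` is a placeholder tensor, use a real preprocessed image in practice):

```python
import torch

net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5)
net.eval()
img = torch.randn(1, 3, 480, 320)   # placeholder 3-channel input
with torch.no_grad():
    mask = net(img).argmax(dim=1)    # class index per pixel (2 classes for Carvana)
print(mask.shape)                    # torch.Size([1, 480, 320])
```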
## Data
The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).
You can also download it using the helper script:
```
bash scripts/download_data.sh
```
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folders should not contain any sub-folders or other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white.
You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.
---
Original paper by Olaf Ronneberger, Philipp Fischer, Thomas Brox:
[U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)
![network architecture](https://i.imgur.com/jeDVpqF.png)
Dataset download link<br>
Link: https://pan.baidu.com/s/1rIZIiR_bNWyeE7YxUTXSzQ?pwd=DTAI <br>
Extraction code: DTAI
import torch
import torch.nn.functional as F
from tqdm import tqdm
from utils.dice_score import multiclass_dice_coeff, dice_coeff
@torch.inference_mode() # run evaluate in inference mode for better performance
def evaluate(net, dataloader, device, amp):
net.eval()
num_val_batches = len(dataloader)
dice_score = 0
# iterate over the validation set
with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
image, mask_true = batch['image'], batch['mask']
# move images and labels to correct device and type
image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
mask_true = mask_true.to(device=device, dtype=torch.long)
# predict the mask
mask_pred = net(image)
if net.n_classes == 1: # single-class segmentation: binarize the predicted mask, then compute the Dice score
assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
# compute the Dice score
dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
else: # multi-class segmentation: convert the true and predicted masks to one-hot and compute the Dice score
assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes['
# convert to one-hot format
mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
# compute the Dice score, ignoring background
dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)
net.train()
return dice_score / max(num_val_batches, 1)
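An illustrative call, assuming a trained `model`, a validation `DataLoader` and a CUDA device are already set up:

```python
val_dice = evaluate(model, val_loader, torch.device('cuda'), amp=True)
print(f'Validation Dice score: {val_dice}')
```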
# hubconf.py acts as the PyTorch Hub configuration file: it defines how to load and use this UNet model through torch.hub
import torch
from unet import UNet as _UNet
def unet_carvana(pretrained=False, scale=0.5):
"""
在Carvana数据集上训练的UNet模型(https://www.kaggle.com/c/carvana-image-masking-challenge/data)。
在预测时,将scale设置为0.5(50%)。
"""
net = _UNet(n_channels=3, n_classes=2, bilinear=False)
if pretrained:
if scale == 0.5:
checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth'
elif scale == 1.0:
checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale1.0_epoch2.pth'
else:
raise RuntimeError('Only 0.5 and 1.0 scales are available')
state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=True)
if 'mask_values' in state_dict:
state_dict.pop('mask_values')
net.load_state_dict(state_dict)
return net
import logging
import os
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from utils.data_loading import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask
def predict_img(net,
full_img,
device,
scale_factor=1,
out_threshold=0.5): # confidence threshold
net.eval()
img = torch.from_numpy(BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False))
img = img.unsqueeze(0)
img = img.to(device=device, dtype=torch.float32)
with torch.no_grad():
output = net(img).cpu()
output = F.interpolate(output, (full_img.size[1], full_img.size[0]), mode='bilinear')
if net.n_classes > 1:
mask = output.argmax(dim=1)
else:
mask = torch.sigmoid(output) > out_threshold
return mask[0].long().squeeze().numpy()
def get_output_filename(input_file):
return f'{os.path.splitext(input_file)[0]}_OUT.png'
def mask_to_image(mask: np.ndarray, mask_values):
if isinstance(mask_values[0], list):
out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8)
elif mask_values == [0, 1]:
out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool)
else:
out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8)
if mask.ndim == 3:
mask = np.argmax(mask, axis=0)
for i, v in enumerate(mask_values):
out[mask == i] = v
return Image.fromarray(out)
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
# Hard-coded parameters
model_path = 'checkpoints/checkpoint_epoch5.pth'
input_files = ['data/imgs/0cdf5b5d0ce1_03.jpg'] # input image path(s) go here
output_dir = 'output' # output folder name
viz = False # whether to visualize the results
no_save = False # set True to skip saving the output masks
mask_threshold = 0.5 # mask pixel-value threshold
scale_factor = 0.5 # scale factor for the input images
bilinear = False # whether to use bilinear upsampling
num_classes = 2 # number of classes
net = UNet(n_channels=3, n_classes=num_classes, bilinear=bilinear)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Loading model {model_path}')
logging.info(f'Using device {device}')
net.to(device=device)
state_dict = torch.load(model_path, map_location=device)
mask_values = state_dict.pop('mask_values', [0, 1])
net.load_state_dict(state_dict)
logging.info('Model loaded!')
for i, filename in enumerate(input_files):
logging.info(f'Predicting image {filename} ...')
img = Image.open(filename)
mask = predict_img(net=net,
full_img=img,
device=device,
scale_factor=scale_factor,
out_threshold=mask_threshold)
if not no_save:
os.makedirs(output_dir, exist_ok=True)
out_filename = os.path.join(output_dir, get_output_filename(os.path.basename(filename)))
result = mask_to_image(mask, mask_values)
result.save(out_filename)
logging.info(f'Mask saved to {out_filename}')
if viz:
logging.info(f'Visualizing results for image {filename} ...')
plot_img_and_mask(img, mask)
# Report the path of the saved prediction
logging.info(f'Prediction saved to {out_filename}')
matplotlib==3.6.2
numpy==1.23.5
Pillow==9.3.0
tqdm==4.64.1
wandb==0.13.5
""" Full assembly of the parts to form the complete network """
""" Full assembly of the parts to form the complete network """
from .unet_parts import *
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=False):
super(UNet, self).__init__()
# Attributes of the UNet
self.n_channels = n_channels # number of channels of the input image
self.n_classes = n_classes # number of output classes
self.bilinear = bilinear # whether to use bilinear upsampling
self.inc = (DoubleConv(n_channels, 64)) # input block
# Downsampling blocks
self.down1 = (Down(64, 128))
self.down2 = (Down(128, 256))
self.down3 = (Down(256, 512))
# The output channels of down4 depend on the upsampling mode:
# with bilinear upsampling the deepest block outputs 1024 // 2 = 512 channels
factor = 2 if bilinear else 1
self.down4 = (Down(512, 1024 // factor))
# Upsampling blocks
self.up1 = (Up(1024, 512 // factor, bilinear))
self.up2 = (Up(512, 256 // factor, bilinear))
self.up3 = (Up(256, 128 // factor, bilinear))
self.up4 = (Up(128, 64, bilinear))
self.outc = (OutConv(64, n_classes)) # output block
# Forward pass
def forward(self, x):
# forward is dispatched through nn.Module.__call__, so it is never invoked explicitly
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
# Wrap the sub-modules with gradient checkpointing to trade compute for memory
def use_checkpointing(self):
self.inc = torch.utils.checkpoint(self.inc)
self.down1 = torch.utils.checkpoint(self.down1)
self.down2 = torch.utils.checkpoint(self.down2)
self.down3 = torch.utils.checkpoint(self.down3)
self.down4 = torch.utils.checkpoint(self.down4)
self.up1 = torch.utils.checkpoint(self.up1)
self.up2 = torch.utils.checkpoint(self.up2)
self.up3 = torch.utils.checkpoint(self.up3)
self.up4 = torch.utils.checkpoint(self.up4)
self.outc = torch.utils.checkpoint(self.outc)
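A minimal shape check of the assembled network (illustrative; a spatial size divisible by 16 keeps the four poolings and up-samplings aligned without padding):

```python
import torch
from unet import UNet

net = UNet(n_channels=3, n_classes=2, bilinear=False)
x = torch.randn(1, 3, 64, 64)
print(net(x).shape)  # expected: torch.Size([1, 2, 64, 64])
```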
""" Parts of the U-Net model """
""" Parts of the U-Net model """
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
# (convolution => [batch norm] => ReLU), applied twice
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
# Max-pool downsampling followed by a double conv
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
# Upsampling followed by a double conv
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# With bilinear upsampling, a plain convolution reduces the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2): # x1 comes from the upsampling path, x2 is the feature map from the downsampling path
x1 = self.up(x1)
# inputs are in CHW format
diffY = x2.size()[2] - x1.size()[2] # height difference between x2 and x1
diffX = x2.size()[3] - x1.size()[3] # width difference between x2 and x1
# pad x1 by these differences so that it matches x2
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1) # concatenate the two feature maps
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
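A small check of the skip-connection padding in `Up.forward` (illustrative sizes, run in the same module):

```python
import torch

up = Up(128, 64, bilinear=False)
x1 = torch.randn(1, 128, 13, 13)  # decoder feature map (smaller, odd size)
x2 = torch.randn(1, 64, 27, 27)   # encoder feature map from the skip connection
print(up(x1, x2).shape)           # x1 is upsampled to 26x26, padded to 27x27 -> torch.Size([1, 64, 27, 27])
```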
++ "b/2-\345\272\224\347\224\250/2.08-\350\257\255\344\271\211\345\210\206\345\211\262/2.08.3-U-net\345\214\273\345\255\246\345\233\276\345\203\217\345\210\206\345\211\262/pytorch-unet/utils/__init__.py"
import logging
import numpy as np
import torch
from PIL import Image
from functools import lru_cache
from functools import partial
from itertools import repeat
from multiprocessing import Pool
from os import listdir
from os.path import splitext, isfile, join
from pathlib import Path
from torch.utils.data import Dataset
from tqdm import tqdm
# Load .npy, .pt/.pth or other image file formats and return a PIL image object.
def load_image(filename):
ext = splitext(filename)[1]
if ext == '.npy':
return Image.fromarray(np.load(filename))
elif ext in ['.pt', '.pth']:
return Image.fromarray(torch.load(filename).numpy())
else:
return Image.open(filename)
# Helper that returns the unique values of a mask file, located from the sample index (idx), the mask directory (mask_dir) and the mask suffix (mask_suffix).
def unique_mask_values(idx, mask_dir, mask_suffix):
mask_file = list(mask_dir.glob(idx + mask_suffix + '.*'))[0]
mask = np.asarray(load_image(mask_file))
if mask.ndim == 2:
return np.unique(mask)
elif mask.ndim == 3:
mask = mask.reshape(-1, mask.shape[-1])
return np.unique(mask, axis=0)
else:
raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}')
class BasicDataset(Dataset):
def __init__(self, images_dir: str, mask_dir: str, scale: float = 1.0, mask_suffix: str = ''):
self.images_dir = Path(images_dir)
self.mask_dir = Path(mask_dir)
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self.scale = scale
self.mask_suffix = mask_suffix
self.ids = [splitext(file)[0] for file in listdir(images_dir) if isfile(join(images_dir, file)) and not file.startswith('.')]
if not self.ids:
raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')
logging.info(f'Creating dataset with {len(self.ids)} examples')
logging.info('Scanning mask files to determine unique values')
with Pool() as p:
unique = list(tqdm(
p.imap(partial(unique_mask_values, mask_dir=self.mask_dir, mask_suffix=self.mask_suffix), self.ids),
total=len(self.ids)
))
self.mask_values = list(sorted(np.unique(np.concatenate(unique), axis=0).tolist()))
logging.info(f'Unique mask values: {self.mask_values}')
def __len__(self):
return len(self.ids)
@staticmethod
def preprocess(mask_values, pil_img, scale, is_mask):
w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
img = np.asarray(pil_img)
if is_mask:
mask = np.zeros((newH, newW), dtype=np.int64)
for i, v in enumerate(mask_values):
if img.ndim == 2:
mask[img == v] = i
else:
mask[(img == v).all(-1)] = i
return mask
else:
if img.ndim == 2:
img = img[np.newaxis, ...]
else:
img = img.transpose((2, 0, 1))
if (img > 1).any():
img = img / 255.0
return img
def __getitem__(self, idx):
name = self.ids[idx]
mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*'))
img_file = list(self.images_dir.glob(name + '.*'))
assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
mask = load_image(mask_file[0])
img = load_image(img_file[0])
assert img.size == mask.size, \
f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'
img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)
mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)
return {
'image': torch.as_tensor(img.copy()).float().contiguous(),
'mask': torch.as_tensor(mask.copy()).long().contiguous()
}
class CarvanaDataset(BasicDataset):
def __init__(self, images_dir, mask_dir, scale=1):
super().__init__(images_dir, mask_dir, scale, mask_suffix='_mask')
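An illustrative way to wire the dataset into a DataLoader (the directory paths follow the README layout and are placeholders):

```python
from torch.utils.data import DataLoader

dataset = CarvanaDataset('data/imgs', 'data/masks', scale=0.5)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
batch = next(iter(loader))
print(batch['image'].shape, batch['mask'].shape)  # e.g. [2, 3, H, W] and [2, H, W]
```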
import torch
from torch import Tensor
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
# Average of Dice coefficient for all batches, or for a single mask
assert input.size() == target.size()
assert input.dim() == 3 or not reduce_batch_first
sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)
inter = 2 * (input * target).sum(dim=sum_dim)
sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
sets_sum = torch.where(sets_sum == 0, inter, sets_sum)
dice = (inter + epsilon) / (sets_sum + epsilon)
return dice.mean()
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
# Average of Dice coefficient for all classes
return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
# Dice loss (objective to minimize) between 0 and 1
fn = multiclass_dice_coeff if multiclass else dice_coeff
return 1 - fn(input, target, reduce_batch_first=True)
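A quick numerical sketch of the Dice loss on dummy one-hot tensors (illustrative, using the functions above):

```python
import torch
import torch.nn.functional as F

labels = torch.randint(0, 2, (2, 64, 64))                   # dummy ground-truth class indices
one_hot = F.one_hot(labels, 2).permute(0, 3, 1, 2).float()  # |batch, classes, H, W|
print(dice_loss(one_hot, one_hot, multiclass=True))         # ~0 for a perfect prediction
```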
# Visualize an input image and its mask
import matplotlib.pyplot as plt
def plot_img_and_mask(img, mask):
classes = mask.max() + 1
fig, ax = plt.subplots(1, classes + 1)
ax[0].set_title('Input image')
ax[0].imshow(img)
for i in range(classes):
ax[i + 1].set_title(f'Mask (class {i + 1})')
ax[i + 1].imshow(mask == i)
plt.xticks([]), plt.yticks([])
plt.show()
++ "b/2-\345\272\224\347\224\250/2.09-\346\226\207\346\234\254\345\244\204\347\220\206/.gitkeep"
++ "b/2-\345\272\224\347\224\250/2.09-\346\226\207\346\234\254\345\244\204\347\220\206/2.09.3-\345\237\272\344\272\216RCNN\345\256\236\347\216\260\346\226\207\346\234\254\345\210\206\347\261\273-RCNN_Text_Classfication/README.md"
from collections import Counter # Counter is used to count word frequencies
def build_dictionary(texts, vocab_size):
counter = Counter() # Counter object for word frequencies
SPECIAL_TOKENS = ['<PAD>', '<UNK>'] # special tokens: padding and unknown word
for word in texts:
counter.update(word) # update the counter with the tokens of each text
# Take the most frequent words and prepend the special tokens to form the final vocabulary
words = [word for word, count in counter.most_common(vocab_size - len(SPECIAL_TOKENS))]
words = SPECIAL_TOKENS + words
# Map each word to its index
word2idx = {word: idx for idx, word in enumerate(words)}
return word2idx # return the word-to-index mapping
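A toy example of the resulting mapping (illustrative):

```python
texts = [['the', 'cat', 'sat'], ['the', 'dog', 'sat']]
word2idx = build_dictionary(texts, vocab_size=6)
print(word2idx)  # {'<PAD>': 0, '<UNK>': 1, 'the': 2, 'sat': 3, ...}
```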
++ "b/2-\345\272\224\347\224\250/2.09-\346\226\207\346\234\254\345\244\204\347\220\206/2.09.3-\345\237\272\344\272\216RCNN\345\256\236\347\216\260\346\226\207\346\234\254\345\210\206\347\261\273-RCNN_Text_Classfication/data/.gitkeep"
import torch
from torch.utils.data import Dataset
class CustomTextDataset(Dataset):
def __init__(self, texts, labels, dictionary):
# unknown tokens map to index 1 (<UNK>)
self.x = [[dictionary.get(token, 1) for token in token_list] for token_list in texts]
self.y = labels
def __len__(self):
"""返回数据集的长度"""
return len(self.x)
def __getitem__(self, idx):
"""返回给定索引处的一条数据"""
return self.x[idx], self.y[idx]
def collate_fn(data, args, pad_idx=0):
"""填充"""
texts, labels = zip(*data)
texts = [s + [pad_idx] * (args.max_len - len(s)) if len(s) < args.max_len else s[:args.max_len] for s in texts]
return torch.LongTensor(texts), torch.LongTensor(labels)
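An illustrative padding example; `SimpleNamespace` merely stands in for the argparse `args` object used in training:

```python
import torch
from types import SimpleNamespace

args = SimpleNamespace(max_len=5)
batch = [([2, 3, 4], 0), ([5, 6, 7, 8, 9, 10], 1)]
texts, labels = collate_fn(batch, args)
print(texts)   # tensor([[2, 3, 4, 0, 0], [5, 6, 7, 8, 9]])
print(labels)  # tensor([0, 1])
```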
import os
import argparse # command-line argument parsing
import logging
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split # data loader and dataset splitting
from build_vocab import build_dictionary
from dataset import CustomTextDataset, collate_fn
from model import RCNN
from trainer import train, evaluate
from utils import read_file
logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) # logging format and level
logger = logging.getLogger(__name__) # module-level logger
def set_seed(args):
random.seed(args.seed) # seed Python's random module
np.random.seed(args.seed) # seed NumPy
torch.manual_seed(args.seed) # seed PyTorch
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed) # seed every available GPU
def main(args):
model = RCNN(vocab_size=args.vocab_size,
embedding_dim=args.embedding_dim,
hidden_size=args.hidden_size,
hidden_size_linear=args.hidden_size_linear,
class_num=args.class_num,
dropout=args.dropout).to(args.device) # build the RCNN model and move it to the target device
if args.n_gpu > 1:
model = torch.nn.DataParallel(model, dim=0) # parallelize across GPUs when several are available
train_texts, train_labels = read_file(args.train_file_path) # read the training data
word2idx = build_dictionary(train_texts, vocab_size=args.vocab_size) # build the vocabulary
logger.info('Dictionary Finished!')
full_dataset = CustomTextDataset(train_texts, train_labels, word2idx) # build the text dataset
num_train_data = len(full_dataset) - args.num_val_data # number of training samples
train_dataset, val_dataset = random_split(full_dataset, [num_train_data, args.num_val_data]) # random train/validation split
train_dataloader = DataLoader(dataset=train_dataset,
collate_fn=lambda x: collate_fn(x, args),
batch_size=args.batch_size,
shuffle=True) # training data loader
valid_dataloader = DataLoader(dataset=val_dataset,
collate_fn=lambda x: collate_fn(x, args),
batch_size=args.batch_size,
shuffle=True) # validation data loader
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) # Adam optimizer
train(model, optimizer, train_dataloader, valid_dataloader, args) # run training
logger.info('******************** Train Finished ********************')
# Test
if args.test_set:
test_texts, test_labels = read_file(args.test_file_path) # read the test data
test_dataset = CustomTextDataset(test_texts, test_labels, word2idx) # build the test dataset
test_dataloader = DataLoader(dataset=test_dataset,
collate_fn=lambda x: collate_fn(x, args),
batch_size=args.batch_size,
shuffle=True) # test data loader
model.load_state_dict(torch.load(os.path.join(args.model_save_path, "best.pt"))) # load the best checkpoint
_, accuracy, precision, recall, f1, cm = evaluate(model, test_dataloader, args) # evaluate and collect metrics
logger.info('-'*50)
logger.info(f'|* TEST SET *| |ACC| {accuracy:>.4f} |PRECISION| {precision:>.4f} |RECALL| {recall:>.4f} |F1| {f1:>.4f}')
logger.info('-'*50)
logger.info('---------------- CONFUSION MATRIX ----------------')
for i in range(len(cm)):
logger.info(cm[i])
logger.info('--------------------------------------------------')
if __name__ == "__main__":
parser = argparse.ArgumentParser() # argument parser
parser.add_argument('--seed', type=int, default=42) # random seed
parser.add_argument('--test_set', action='store_true', default=False) # evaluate on the test set after training
# Data-related arguments
parser.add_argument("--train_file_path", type=str, default="./data/train.csv") # training data file path
parser.add_argument("--test_file_path", type=str, default="./data/test.csv") # test data file path
parser.add_argument("--model_save_path", type=str, default="./model_saved") # model save path
parser.add_argument("--num_val_data", type=int, default=10000) # number of validation samples
parser.add_argument("--max_len", type=int, default=64) # maximum sequence length
parser.add_argument("--batch_size", type=int, default=64) # batch size
# Model-related arguments
parser.add_argument("--vocab_size", type=int, default=8000) # vocabulary size
parser.add_argument("--embedding_dim", type=int, default=300) # word-embedding dimension
parser.add_argument("--hidden_size", type=int, default=512) # LSTM hidden size
parser.add_argument("--hidden_size_linear", type=int, default=512) # linear layer hidden size
parser.add_argument("--class_num", type=int, default=4) # number of classes
parser.add_argument("--dropout", type=float, default=0.0) # dropout probability
# Training-related arguments
parser.add_argument("--epochs", type=int, default=10) # number of training epochs
parser.add_argument("--lr", type=float, default=3e-4) # learning rate
args = parser.parse_args() # parse command-line arguments
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # use GPU when available
args.n_gpu = torch.cuda.device_count() # number of available GPUs
set_seed(args) # set random seeds
main(args) # run the main function
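An illustrative invocation (the defaults above expect AG NEWS CSVs under ./data):

```bash
python main.py --epochs 10 --lr 3e-4 --test_set
```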
import torch
import torch.nn as nn
import torch.nn.functional as F
class RCNN(nn.Module):
"""
Recurrent Convolutional Neural Networks for Text Classification (2015)
"""
def __init__(self, vocab_size, embedding_dim, hidden_size, hidden_size_linear, class_num, dropout):
super(RCNN, self).__init__()
# Embedding layer: maps word indices to dense vectors
self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
# Bidirectional LSTM for sequence context
self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True, bidirectional=True, dropout=dropout)
# Linear projection of the concatenated LSTM output and embedding
self.W = nn.Linear(embedding_dim + 2*hidden_size, hidden_size_linear)
# Tanh activation
self.tanh = nn.Tanh()
# Final fully connected layer mapping features to class scores
self.fc = nn.Linear(hidden_size_linear, class_num)
def forward(self, x):
# x: |batch_size, seq_len|
# embedding: x_emb is |batch_size, seq_len, embedding_dim|
x_emb = self.embedding(x)
# LSTM: output is |batch_size, seq_len, 2*hidden_size|
output, _ = self.lstm(x_emb)
# concatenate the LSTM output with the embeddings: |batch_size, seq_len, embedding_dim + 2*hidden_size|
output = torch.cat([output, x_emb], 2)
# linear projection + Tanh, then transpose to |batch_size, hidden_size_linear, seq_len|
output = self.tanh(self.W(output)).transpose(1, 2)
# max-pool over the sequence dimension: |batch_size, hidden_size_linear|
output = F.max_pool1d(output, output.size(2)).squeeze(2)
# fully connected layer: |batch_size, class_num|
output = self.fc(output)
return output
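A minimal shape check for the RCNN model (illustrative hyper-parameters):

```python
import torch

model = RCNN(vocab_size=100, embedding_dim=32, hidden_size=16,
             hidden_size_linear=24, class_num=4, dropout=0.0)
x = torch.randint(0, 100, (8, 20))   # |batch, seq_len|
print(model(x).shape)                # expected: torch.Size([8, 4])
```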
import os
import logging
import torch
import torch.nn.functional as F
from utils import metrics
# Logging format and level
logging.basicConfig(format='%(asctime)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
logger = logging.getLogger(__name__)
# Training loop
def train(model, optimizer, train_dataloader, valid_dataloader, args):
best_f1 = 0
logger.info('Start Training!')
for epoch in range(1, args.epochs + 1):
model.train() # set the model to training mode
for step, (x, y) in enumerate(train_dataloader):
x, y = x.to(args.device), y.to(args.device) # move the batch to the target device (e.g. GPU)
pred = model(x) # forward pass
loss = F.cross_entropy(pred, y) # cross-entropy loss
optimizer.zero_grad() # reset gradients
loss.backward() # backward pass
optimizer.step() # update parameters
if (step + 1) % 200 == 0:
logger.info(f'|EPOCHS| {epoch:>}/{args.epochs} |STEP| {step + 1:>4}/{len(train_dataloader)} |LOSS| {loss.item():>.4f}')
avg_loss, accuracy, _, _, f1, _ = evaluate(model, valid_dataloader, args) # evaluate on the validation set
logger.info('-' * 50)
logger.info(f'|* VALID SET *| |VAL LOSS| {avg_loss:>.4f} |ACC| {accuracy:>.4f} |F1| {f1:>.4f}')
logger.info('-' * 50)
if f1 > best_f1:
best_f1 = f1
logger.info(f'Saving best model... F1 score is {best_f1:>.4f}')
if not os.path.isdir(args.model_save_path):
os.mkdir(args.model_save_path)
torch.save(model.state_dict(), os.path.join(args.model_save_path, "best.pt")) # save the best model
logger.info('Model saved!')
# Evaluation loop
def evaluate(model, valid_dataloader, args):
with torch.no_grad(): # disable gradient computation
model.eval() # set the model to evaluation mode
losses, correct = 0, 0
y_hats, targets = [], []
for x, y in valid_dataloader:
x, y = x.to(args.device), y.to(args.device)
pred = model(x)
loss = F.cross_entropy(pred, y) # cross-entropy loss
losses += loss.item()
y_hat = torch.max(pred, 1)[1] # index of the highest-scoring class
y_hats += y_hat.tolist()
targets += y.tolist()
correct += (y_hat == y).sum().item() # count correct predictions
avg_loss, accuracy, precision, recall, f1, cm = metrics(valid_dataloader, losses, correct, y_hats, targets) # compute metrics
return avg_loss, accuracy, precision, recall, f1, cm
import pandas as pd
import re
import nltk
from nltk.tokenize import word_tokenize
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
nltk.download('punkt') # download nltk's punkt tokenizer models
def read_file(file_path):
"""
Read the AG NEWS dataset
"""
data = pd.read_csv(file_path, names=["class", "title", "description"]) # read the CSV with explicit column names
texts = list(data['title'].values + ' ' + data['description'].values) # concatenate title and description into one text
texts = [word_tokenize(preprocess_text(sentence)) for sentence in texts] # preprocess and tokenize each text
labels = [label-1 for label in list(data['class'].values)] # shift labels from 1-4 to 0-3
return texts, labels # return the processed texts and labels
def preprocess_text(string):
"""
Basic text cleaning
"""
string = string.lower() # lowercase
string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string) # drop everything except letters, digits and a few punctuation marks
string = re.sub(r"\'s", " \'s", string) # split possessive 's
string = re.sub(r"\'ve", " \'ve", string) # split 've
string = re.sub(r"n\'t", " n\'t", string) # split negation n't
string = re.sub(r"\'re", " \'re", string) # split 're
string = re.sub(r"\'d", " \'d", string) # split 'd
string = re.sub(r"\'ll", " \'ll", string) # split 'll
string = re.sub(r",", " , ", string) # space out commas
string = re.sub(r"!", " ! ", string) # space out exclamation marks
string = re.sub(r"\(", " \( ", string) # space out left parentheses
string = re.sub(r"\)", " \) ", string) # space out right parentheses
string = re.sub(r"\?", " \? ", string) # space out question marks
string = re.sub(r"\s{2,}", " ", string) # collapse multiple spaces
return string.strip() # strip leading/trailing whitespace
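An illustrative input/output pair for the cleaning function above:

```python
print(preprocess_text("Wall St. Bears Claw Back Into the Black (Reuters)"))
# -> "wall st bears claw back into the black \( reuters \)"
```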
def metrics(dataloader, losses, correct, y_hats, targets):
"""
Compute performance metrics
"""
avg_loss = losses / len(dataloader) # average loss
accuracy = correct / len(dataloader.dataset) * 100 # accuracy (percent)
precision = precision_score(targets, y_hats, average='macro') # precision
recall = recall_score(targets, y_hats, average='macro') # recall
f1 = f1_score(targets, y_hats, average='macro') # F1 score
cm = confusion_matrix(targets, y_hats) # confusion matrix
return avg_loss, accuracy, precision, recall, f1, cm # return all metrics
++ "b/2-\345\272\224\347\224\250/2.10-\350\257\255\351\237\263\345\244\204\347\220\206/.gitkeep"
++ "b/2-\345\272\224\347\224\250/2.10-\350\257\255\351\237\263\345\244\204\347\220\206/2.10.3-DeepSpeech2\347\253\257\345\210\260\347\253\257\344\270\255\346\226\207\350\257\255\351\237\263\350\257\206\345\210\253\346\250\241\345\236\213\345\256\236\347\216\260/.gitkeep"
++ "b/2-\345\272\224\347\224\250/2.11-\345\206\263\347\255\226\344\270\216\350\247\204\345\210\222/.gitkeep"
++ "b/2-\345\272\224\347\224\250/2.11-\345\206\263\347\255\226\344\270\216\350\247\204\345\210\222/2.11.4-\351\251\254\345\260\224\345\217\257\345\244\253\350\247\243\345\206\263\346\234\200\344\274\230\345\214\226\351\227\256\351\242\230/README.md"
++ "b/2-\345\272\224\347\224\250/2.12-\347\233\256\346\240\207\346\243\200\346\265\213/2.12.01-YOLO V1\346\250\241\345\236\213\350\257\246\350\247\243\344\270\216\344\273\243\347\240\201/.gitkeep"
YOLOv1 code download link<br>
Link: https://pan.baidu.com/s/1bZZyT5ld2kn3ylVBmIPjeA?pwd=m1s4<br>
Extraction code: m1s4<br>
<br>
<br>
YOLOv1 and YOLOv3 dataset download link<br>
Link: https://pan.baidu.com/s/1wLuwIb29cu4kYcD_bHeU_Q?pwd=qjeh<br>
Extraction code: qjeh
++ "b/2-\345\272\224\347\224\250/2.12-\347\233\256\346\240\207\346\243\200\346\265\213/2.12.02-YOLO V2\346\250\241\345\236\213\350\257\246\350\247\243\344\270\216\344\273\243\347\240\201/.gitkeep"
YOLOv2 code download link<br>
Link: https://pan.baidu.com/s/1_rEQjEhyYuOgFxIL069e0w?pwd=j8bc<br>
Extraction code: j8bc<br>
<br>
<br>
YOLOv2 dataset download link<br>
Link: https://pan.baidu.com/s/1ndJVDT_3QUErvcZMdBpkOQ?pwd=rknt<br>
Extraction code: rknt
++ "b/2-\345\272\224\347\224\250/2.12-\347\233\256\346\240\207\346\243\200\346\265\213/2.12.03-YOLO V3\346\250\241\345\236\213\350\257\246\350\247\243\344\270\216\344\273\243\347\240\201/.gitkeep"
YOLOv3 code download link<br>
Link: https://pan.baidu.com/s/1AOdYU1XtAKsBFytw4OxeXQ?pwd=uxsp<br>
Extraction code: uxsp<br>
<br>
<br>
YOLOv1 and YOLOv3 dataset download link<br>
Link: https://pan.baidu.com/s/1wLuwIb29cu4kYcD_bHeU_Q?pwd=qjeh<br>
Extraction code: qjeh
++ "b/2-\345\272\224\347\224\250/2.12-\347\233\256\346\240\207\346\243\200\346\265\213/2.12.04-YOLO V4\346\250\241\345\236\213\350\257\246\350\247\243\344\270\216\344\273\243\347\240\201/.gitkeep"
YOLOv4 code download link<br>
Link: https://pan.baidu.com/s/1K9KQJsp4ZVbrLNwA964rZA?pwd=m7qk<br>
Extraction code: m7qk<br>
++ "b/2-\345\272\224\347\224\250/2.12-\347\233\256\346\240\207\346\243\200\346\265\213/2.12.05-YOLO V5\346\250\241\345\236\213\350\257\246\350\247\243\344\270\216\344\273\243\347\240\201/.gitkeep"
YOLOv5 code download link<br>
Link: https://pan.baidu.com/s/1l7FJQH7gVIX3SH-p1R4hCw?pwd=efgy<br>
Extraction code: efgy<br>
++ "b/2-\345\272\224\347\224\250/2.12-\347\233\256\346\240\207\346\243\200\346\265\213/2.12.06-YOLO V6\346\250\241\345\236\213\350\257\246\350\247\243\344\270\216\344\273\243\347\240\201/.gitkeep"
YOLOv6 code download link<br>
Link: https://pan.baidu.com/s/1PMXuSnt4RCavWQaF9fbkYw<br>
Extraction code: hzey<br>
++ "b/2-\345\272\224\347\224\250/2.12-\347\233\256\346\240\207\346\243\200\346\265\213/2.12.07-YOLO V7\346\250\241\345\236\213\350\257\246\350\247\243\344\270\216\344\273\243\347\240\201/.gitkeep"
YOLOv7 code download link<br>
Link: https://pan.baidu.com/s/1YoXnH9zFsf4UMyBP_taEtA<br>
Extraction code: nbjr<br>
++ "b/2-\345\272\224\347\224\250/2.12-\347\233\256\346\240\207\346\243\200\346\265\213/2.12.08-YOLO V8\346\250\241\345\236\213\350\257\246\350\247\243\344\270\216\344\273\243\347\240\201/.gitkeep"
YOLOv8 code download link<br>
Link: https://pan.baidu.com/s/1yJd4LL2ufCaKTsUU3Kgvaw?pwd=zccx<br>
Extraction code: zccx<br>
++ "b/2-\345\272\224\347\224\250/2.12-\347\233\256\346\240\207\346\243\200\346\265\213/2.12.09-YOLO V9\346\250\241\345\236\213\350\257\246\350\247\243\344\270\216\344\273\243\347\240\201/.gitkeep"
YOLOv9 code download link<br>
Link: https://pan.baidu.com/s/1ngu92yZC-Y0eDNizUamQ9w?pwd=5ck2<br>
Extraction code: 5ck2<br>
++ "b/2-\345\272\224\347\224\250/2.12-\347\233\256\346\240\207\346\243\200\346\265\213/2.12.10-YOLOV3\345\256\236\347\216\260\347\233\256\346\240\207\346\243\200\346\265\213/.gitkeep"
Code download link<br>
Link: https://pan.baidu.com/s/1FgSjIn9Kp3TbQfh8iaiIdA?pwd=zsbh<br>
Extraction code: zsbh
++ "b/2-\345\272\224\347\224\250/2.13-\346\267\261\345\272\246\345\274\272\345\214\226\345\255\246\344\271\240/.gitkeep"
++ "b/2-\345\272\224\347\224\250/2.13-\346\267\261\345\272\246\345\274\272\345\214\226\345\255\246\344\271\240/2.13.1-Value-based Learning\347\275\221\347\273\234\350\257\246\350\247\243\344\270\216\344\273\243\347\240\201/.gitkeep"
import numpy as np
import gym
import time
import turtle
assert gym.__version__ == "0.18.0", "[Version WARNING] please try `pip install gym==0.18.0`"
class FrozenLakeWapper(gym.Wrapper):
def __init__(self, env):
gym.Wrapper.__init__(self, env)
self.max_y = env.desc.shape[0]
self.max_x = env.desc.shape[1]
self.t = None
self.unit = 50
def draw_box(self, x, y, fillcolor='', line_color='gray'):
self.t.up()
self.t.goto(x * self.unit, y * self.unit)
self.t.color(line_color)
self.t.fillcolor(fillcolor)
self.t.setheading(90)
self.t.down()
self.t.begin_fill()
for _ in range(4):
self.t.forward(self.unit)
self.t.right(90)
self.t.end_fill()
def move_player(self, x, y):
self.t.up()
self.t.setheading(90)
self.t.fillcolor('red')
self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)
def render(self):
if self.t == None:
self.t = turtle.Turtle()
self.wn = turtle.Screen()
self.wn.setup(self.unit * self.max_x + 100,
self.unit * self.max_y + 100)
self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
self.unit * self.max_y)
self.t.shape('circle')
self.t.width(2)
self.t.speed(0)
self.t.color('gray')
for i in range(self.desc.shape[0]):
for j in range(self.desc.shape[1]):
x = j
y = self.max_y - 1 - i
if self.desc[i][j] == b'S': # Start
self.draw_box(x, y, 'white')
elif self.desc[i][j] == b'F': # Frozen ice
self.draw_box(x, y, 'white')
elif self.desc[i][j] == b'G': # Goal
self.draw_box(x, y, 'yellow')
elif self.desc[i][j] == b'H': # Hole
self.draw_box(x, y, 'black')
else:
self.draw_box(x, y, 'white')
self.t.shape('turtle')
x_pos = self.s % self.max_x
y_pos = self.max_y - 1 - int(self.s / self.max_x)
self.move_player(x_pos, y_pos)
class CliffWalkingWapper(gym.Wrapper):
def __init__(self, env):
gym.Wrapper.__init__(self, env)
self.t = None
self.unit = 50
self.max_x = 12
self.max_y = 4
def draw_x_line(self, y, x0, x1, color='gray'):
assert x1 > x0
self.t.color(color)
self.t.setheading(0)
self.t.up()
self.t.goto(x0, y)
self.t.down()
self.t.forward(x1 - x0)
def draw_y_line(self, x, y0, y1, color='gray'):
assert y1 > y0
self.t.color(color)
self.t.setheading(90)
self.t.up()
self.t.goto(x, y0)
self.t.down()
self.t.forward(y1 - y0)
def draw_box(self, x, y, fillcolor='', line_color='gray'):
self.t.up()
self.t.goto(x * self.unit, y * self.unit)
self.t.color(line_color)
self.t.fillcolor(fillcolor)
self.t.setheading(90)
self.t.down()
self.t.begin_fill()
for i in range(4):
self.t.forward(self.unit)
self.t.right(90)
self.t.end_fill()
def move_player(self, x, y):
self.t.up()
self.t.setheading(90)
self.t.fillcolor('red')
self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)
def render(self):
if self.t == None:
self.t = turtle.Turtle()
self.wn = turtle.Screen()
self.wn.setup(self.unit * self.max_x + 100,
self.unit * self.max_y + 100)
self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
self.unit * self.max_y)
self.t.shape('circle')
self.t.width(2)
self.t.speed(0)
self.t.color('gray')
for _ in range(2):
self.t.forward(self.max_x * self.unit)
self.t.left(90)
self.t.forward(self.max_y * self.unit)
self.t.left(90)
for i in range(1, self.max_y):
self.draw_x_line(
y=i * self.unit, x0=0, x1=self.max_x * self.unit)
for i in range(1, self.max_x):
self.draw_y_line(
x=i * self.unit, y0=0, y1=self.max_y * self.unit)
for i in range(1, self.max_x - 1):
self.draw_box(i, 0, 'black')
self.draw_box(self.max_x - 1, 0, 'yellow')
self.t.shape('turtle')
x_pos = self.s % self.max_x
y_pos = self.max_y - 1 - int(self.s / self.max_x)
self.move_player(x_pos, y_pos)
class QLearningAgent(object):
def __init__(self,
obs_n,
act_n,
learning_rate=0.01,
gamma=0.9,
e_greed=0.1):
self.act_n = act_n # number of available actions
self.lr = learning_rate # learning rate
self.gamma = gamma # reward discount factor
self.epsilon = e_greed # probability of taking a random action
self.Q = np.zeros((obs_n, act_n))
# Sample an action for the given observation, with exploration
def sample(self, obs):
if np.random.uniform(0, 1) < (1.0 - self.epsilon): # pick the action with the highest Q value
action = self.predict(obs)
else:
action = np.random.choice(self.act_n) # otherwise explore with a random action
return action
# Predict the greedy action for the given observation
def predict(self, obs):
Q_list = self.Q[obs, :]
maxQ = np.max(Q_list)
action_list = np.where(Q_list == maxQ)[0] # maxQ may correspond to several actions
action = np.random.choice(action_list)
return action
# Learning step: update the Q-table
def learn(self, obs, action, reward, next_obs, done):
""" off-policy
obs: observation before the interaction, s_t
action: action taken in this interaction, a_t
reward: reward received for this action, r
next_obs: observation after the interaction, s_t+1
done: whether the episode has ended
"""
predict_Q = self.Q[obs, action]
if done:
target_Q = reward # there is no next state
else:
target_Q = reward + self.gamma * np.max(self.Q[next_obs, :]) # Q-learning
self.Q[obs, action] += self.lr * (target_Q - predict_Q) # move Q towards the target
# Save the Q-table to a file
def save(self):
npy_file = './q_table.npy'
np.save(npy_file, self.Q)
print(npy_file + ' saved.')
# Load the Q-table from a file
def restore(self, npy_file='./q_table.npy'):
self.Q = np.load(npy_file)
print(npy_file + ' loaded.')
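A tiny numeric illustration of the tabular update in `learn` (illustrative values, using the class above):

```python
agent = QLearningAgent(obs_n=4, act_n=2, learning_rate=0.1, gamma=0.9, e_greed=0.1)
agent.Q[1, :] = [0.0, 0.5]
agent.learn(obs=0, action=0, reward=1.0, next_obs=1, done=False)
print(agent.Q[0, 0])  # 0.1 * (1.0 + 0.9 * 0.5 - 0.0) = 0.145
```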
def run_episode(env, agent, render=False):
total_steps = 0 # number of steps in this episode
total_reward = 0
obs = env.reset() # reset the environment and start a new episode
while True:
action = agent.sample(obs) # choose an action with the agent's policy
next_obs, reward, done, _ = env.step(action) # interact with the environment
# Q-learning update
agent.learn(obs, action, reward, next_obs, done)
obs = next_obs # keep the new observation
total_reward += reward
total_steps += 1 # count steps
if render:
env.render() # render a new frame
if done:
break
return total_reward, total_steps
def test_episode(env, agent):
total_reward = 0
obs = env.reset()
while True:
action = agent.predict(obs) # greedy
next_obs, reward, done, _ = env.step(action)
total_reward += reward
obs = next_obs
time.sleep(0.5)
env.render()
if done:
print('test reward = %.1f' % (total_reward))
break
def main():
# env = gym.make("FrozenLake-v0", is_slippery=False) # 0 left, 1 down, 2 right, 3 up
# env = FrozenLakeWapper(env)
env = gym.make("CliffWalking-v0") # 0 up, 1 right, 2 down, 3 left
env = CliffWalkingWapper(env)
agent = QLearningAgent(
obs_n=env.observation_space.n,
act_n=env.action_space.n,
learning_rate=0.1,
gamma=0.9,
e_greed=0.1)
is_render = False
for episode in range(500):
ep_reward, ep_steps = run_episode(env, agent, is_render)
print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps,
ep_reward))
# render every 20 episodes to watch the progress
if episode % 20 == 0:
is_render = True
else:
is_render = False
# after training, check how the learned policy performs
test_episode(env, agent)
if __name__ == "__main__":
main()
++ "b/2-\345\272\224\347\224\250/2.13-\346\267\261\345\272\246\345\274\272\345\214\226\345\255\246\344\271\240/2.13.2-Policy-based Learning\347\275\221\347\273\234\350\257\246\350\247\243\344\270\216\344\273\243\347\240\201/.gitkeep"
import gym
import numpy as np
import torch
import matplotlib.pyplot as plt
# CartPole-v0
# Acrobot-v1
# MountainCar-v0
env = gym.make("MountainCar-v0")
l1 = env.observation_space.shape[0] # observation dimension (2 for MountainCar-v0)
l2 = 150 # hidden layer size
l3 = env.action_space.n # number of discrete actions (3 for MountainCar-v0)
model = torch.nn.Sequential(
torch.nn.Linear(l1, l2),
torch.nn.LeakyReLU(), # LeakyReLU activation; swapping it for plain ReLU (or removing it) tends to hurt performance here
torch.nn.Linear(l2, l3),
torch.nn.Softmax(dim=0) # output is a softmax probability distribution over actions
)
learning_rate = 0.009
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# state1 = env.reset()
# pred = model(torch.from_numpy(state1).float()) # query the policy network for action probabilities
# action = np.random.choice(np.array([0, 1]), p=pred.data.numpy()) # sample an action from the predicted distribution
# state2, reward, done, info = env.step(action) # take the action and observe the new state and reward
def discount_rewards(rewards, gamma=0.99):
lenr = len(rewards)
disc_return = torch.pow(gamma, torch.arange(lenr).float()) * rewards # exponentially discounted rewards
disc_return /= disc_return.max() # normalize the rewards to [0, 1] for numerical stability
return disc_return
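A small worked example of the discounting above (illustrative):

```python
import torch

print(discount_rewards(torch.ones(3)))  # tensor([1.0000, 0.9900, 0.9801])
```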
def loss_fn(preds, r): # expects the probabilities of the actions taken and the discounted rewards
return -1 * torch.sum(r * torch.log(preds)) # log-probabilities of the actions, weighted by the rewards, summed and negated
MAX_DUR = 200
MAX_EPISODES = 500
gamma = 0.99
score = [] # episode lengths recorded during training
expectation = 0.0
for episode in range(MAX_EPISODES):
curr_state = env.reset()
# env.render()
done = False
transitions = [] # list of (state, action, reward) tuples (the environment reward itself is ignored)
for t in range(MAX_DUR):
act_prob = model(torch.from_numpy(curr_state).float()) # action probabilities from the policy network
action = np.random.choice(np.arange(l3), p=act_prob.data.numpy()) # sample an action from that distribution
prev_state = curr_state
curr_state, _, done, info = env.step(action) # take the action in the environment
transitions.append((prev_state, action, t + 1)) # store the transition
if done: # stop the episode when the environment signals termination
break
ep_len = len(transitions)
score.append(ep_len) # record the episode length
print(ep_len)
reward_batch = torch.Tensor([r for (s, a, r) in transitions]).flip(dims=(0,)) # collect all rewards of the episode in one tensor
disc_returns = discount_rewards(reward_batch) # compute discounted returns
state_batch = torch.Tensor([s for (s, a, r) in transitions]) # collect the episode's states in one tensor
action_batch = torch.Tensor([a for (s, a, r) in transitions]) # collect the episode's actions in one tensor
pred_batch = model(state_batch) # recompute action probabilities for all states of the episode
prob_batch = pred_batch.gather(dim=1, index=action_batch.long().view(-1, 1)).squeeze() # probabilities of the actions actually taken
loss = loss_fn(prob_batch, disc_returns)
optimizer.zero_grad()
loss.backward()
optimizer.step()
env.render()
score = np.array(score)
# avg_score = running_mean(score, 50)
plt.figure(figsize=(10, 7))
plt.ylabel("Episode Duration", fontsize=22)
plt.xlabel("Training Epochs", fontsize=22)
plt.plot(score, color='green')
plt.show()
# Auto detect text files and perform LF normalization
* text=auto
# Default ignored files
/shelf/
/workspace.xml
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Editor-based HTTP Client requests
/httpRequests/
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="pytest" />
</component>
</module>
<component name="InspectionProjectProfileManager">
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
</project>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/DDPG-main.iml" filepath="$PROJECT_DIR$/.idea/DDPG-main.iml" />
</modules>
</component>
</project>
import torch as T
import torch.nn.functional as F
import numpy as np
from networks import ActorNetwork, CriticNetwork
from buffer import ReplayBuffer
device = T.device("cuda:0" if T.cuda.is_available() else "cpu")
class DDPG:
    def __init__(self, alpha, beta, state_dim, action_dim, actor_fc1_dim,
                 actor_fc2_dim, critic_fc1_dim, critic_fc2_dim, ckpt_dir,
                 gamma=0.99, tau=0.005, action_noise=0.1, max_size=1000000,
                 batch_size=256):
        self.gamma = gamma
        self.tau = tau
        self.action_noise = action_noise
        self.checkpoint_dir = ckpt_dir

        self.actor = ActorNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim,
                                  fc1_dim=actor_fc1_dim, fc2_dim=actor_fc2_dim)
        self.target_actor = ActorNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim,
                                         fc1_dim=actor_fc1_dim, fc2_dim=actor_fc2_dim)
        self.critic = CriticNetwork(beta=beta, state_dim=state_dim, action_dim=action_dim,
                                    fc1_dim=critic_fc1_dim, fc2_dim=critic_fc2_dim)
        self.target_critic = CriticNetwork(beta=beta, state_dim=state_dim, action_dim=action_dim,
                                           fc1_dim=critic_fc1_dim, fc2_dim=critic_fc2_dim)
        self.memory = ReplayBuffer(max_size=max_size, state_dim=state_dim, action_dim=action_dim,
                                   batch_size=batch_size)

        # Hard-copy the online weights into the target networks (tau=1.0).
        self.update_network_parameters(tau=1.0)

    def update_network_parameters(self, tau=None):
        # Soft update: target <- tau * online + (1 - tau) * target
        if tau is None:
            tau = self.tau

        for actor_params, target_actor_params in zip(self.actor.parameters(),
                                                     self.target_actor.parameters()):
            target_actor_params.data.copy_(tau * actor_params + (1 - tau) * target_actor_params)

        for critic_params, target_critic_params in zip(self.critic.parameters(),
                                                       self.target_critic.parameters()):
            target_critic_params.data.copy_(tau * critic_params + (1 - tau) * target_critic_params)

    def remember(self, state, action, reward, state_, done):
        self.memory.store_transition(state, action, reward, state_, done)

    def choose_action(self, observation, train=True):
        self.actor.eval()
        state = T.tensor([observation], dtype=T.float).to(device)
        action = self.actor.forward(state).squeeze()

        if train:
            # A single scalar noise sample is drawn and broadcast across all action dimensions.
            noise = T.tensor(np.random.normal(loc=0.0, scale=self.action_noise),
                             dtype=T.float).to(device)
            action = T.clamp(action + noise, -1, 1)
        self.actor.train()

        return action.detach().cpu().numpy()

    def learn(self):
        # Wait until the replay buffer holds at least one full batch.
        if not self.memory.ready():
            return

        states, actions, reward, states_, terminals = self.memory.sample_buffer()
        states_tensor = T.tensor(states, dtype=T.float).to(device)
        actions_tensor = T.tensor(actions, dtype=T.float).to(device)
        rewards_tensor = T.tensor(reward, dtype=T.float).to(device)
        next_states_tensor = T.tensor(states_, dtype=T.float).to(device)
        terminals_tensor = T.tensor(terminals).to(device)

        # Critic update: regress Q(s, a) towards r + gamma * Q'(s', mu'(s')).
        with T.no_grad():
            next_actions_tensor = self.target_actor.forward(next_states_tensor)
            q_ = self.target_critic.forward(next_states_tensor, next_actions_tensor).view(-1)
            q_[terminals_tensor] = 0.0
            target = rewards_tensor + self.gamma * q_
        q = self.critic.forward(states_tensor, actions_tensor).view(-1)

        critic_loss = F.mse_loss(q, target.detach())
        self.critic.optimizer.zero_grad()
        critic_loss.backward()
        self.critic.optimizer.step()

        # Actor update: maximize the critic's estimate of Q(s, mu(s)).
        new_actions_tensor = self.actor.forward(states_tensor)
        actor_loss = -T.mean(self.critic(states_tensor, new_actions_tensor))
        self.actor.optimizer.zero_grad()
        actor_loss.backward()
        self.actor.optimizer.step()

        self.update_network_parameters()

    def save_models(self, episode):
        self.actor.save_checkpoint(self.checkpoint_dir + 'Actor/DDPG_actor_{}.pth'.format(episode))
        print('Saving actor network successfully!')
        self.target_actor.save_checkpoint(self.checkpoint_dir +
                                          'Target_actor/DDPG_target_actor_{}.pth'.format(episode))
        print('Saving target_actor network successfully!')
        self.critic.save_checkpoint(self.checkpoint_dir + 'Critic/DDPG_critic_{}.pth'.format(episode))
        print('Saving critic network successfully!')
        self.target_critic.save_checkpoint(self.checkpoint_dir +
                                           'Target_critic/DDPG_target_critic_{}.pth'.format(episode))
        print('Saving target critic network successfully!')

    def load_models(self, episode):
        self.actor.load_checkpoint(self.checkpoint_dir + 'Actor/DDPG_actor_{}.pth'.format(episode))
        print('Loading actor network successfully!')
        self.target_actor.load_checkpoint(self.checkpoint_dir +
                                          'Target_actor/DDPG_target_actor_{}.pth'.format(episode))
        print('Loading target_actor network successfully!')
        self.critic.load_checkpoint(self.checkpoint_dir + 'Critic/DDPG_critic_{}.pth'.format(episode))
        print('Loading critic network successfully!')
        self.target_critic.load_checkpoint(self.checkpoint_dir +
                                           'Target_critic/DDPG_target_critic_{}.pth'.format(episode))
        print('Loading target critic network successfully!')
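The commit does not include the train.py script that the README below refers to. A minimal training-loop sketch, assuming the same LunarLanderContinuous-v2 environment, the scale_action helper from utils.py, and the hyperparameters used in test.py (episode count and save interval are illustrative, not taken from the project):

# Hypothetical training-loop sketch; it only exercises the agent interface
# defined above (choose_action, remember, learn, save_models).
import gym
from DDPG import DDPG
from utils import scale_action

def train(max_episodes=1000, ckpt_dir='./checkpoints/DDPG/'):
    env = gym.make('LunarLanderContinuous-v2')
    agent = DDPG(alpha=0.0003, beta=0.0003,
                 state_dim=env.observation_space.shape[0],
                 action_dim=env.action_space.shape[0],
                 actor_fc1_dim=400, actor_fc2_dim=300,
                 critic_fc1_dim=400, critic_fc2_dim=300,
                 ckpt_dir=ckpt_dir, batch_size=256)
    for episode in range(1, max_episodes + 1):
        observation = env.reset()
        done, episode_return = False, 0.0
        while not done:
            # Exploration noise is added because train=True.
            action = agent.choose_action(observation, train=True)
            action_ = scale_action(action.copy(), env.action_space.high,
                                   env.action_space.low)
            observation_, reward, done, info = env.step(action_)
            agent.remember(observation, action, reward, observation_, done)
            agent.learn()
            episode_return += reward
            observation = observation_
        print('Episode {}: return = {:.1f}'.format(episode, episode_return))
        if episode % 200 == 0:  # illustrative save interval
            agent.save_models(episode)

if __name__ == '__main__':
    train()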
MIT License
Copyright (c) 2022 indigoLovee
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# DDPG
DDPG in PyTorch
# Simulation environment
LunarLanderContinuous-v2 from gym
# Dependencies
* gym
* numpy
* matplotlib
* python3.6
* pytorch1.6
# File description
* train.py is the training script; once the environment is set up it can be run directly, but an output_images folder must first be created in the current directory to hold the generated simulation results;
* networks.py is the network script, containing the actor network and the critic network;
* buffer.py is the experience replay buffer script;
* DDPG.py implements the DDPG algorithm;
* utils.py is the utility script, holding mainly general-purpose helper functions;
* test.py is the test script; it loads the trained weights and runs them in the environment to evaluate the training results.
# Simulation results
See the output_images folder
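Note that save_models() in DDPG.py writes into Actor/, Target_actor/, Critic/ and Target_critic/ subfolders of the checkpoint directory, and test.py writes its gif into output_images/, but neither script creates these directories itself. A small setup sketch (the paths mirror the defaults in test.py; adjust as needed):

import os

# Create the folders that save_models() and test.py expect to exist.
for sub in ('Actor', 'Target_actor', 'Critic', 'Target_critic'):
    os.makedirs(os.path.join('./checkpoints/DDPG/', sub), exist_ok=True)
os.makedirs('./output_images', exist_ok=True)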
import numpy as np
class ReplayBuffer:
    def __init__(self, max_size, state_dim, action_dim, batch_size):
        self.mem_size = max_size
        self.batch_size = batch_size
        self.mem_cnt = 0

        self.state_memory = np.zeros((self.mem_size, state_dim))
        self.action_memory = np.zeros((self.mem_size, action_dim))
        self.reward_memory = np.zeros((self.mem_size, ))
        self.next_state_memory = np.zeros((self.mem_size, state_dim))
        self.terminal_memory = np.zeros((self.mem_size, ), dtype=np.bool_)

    def store_transition(self, state, action, reward, state_, done):
        mem_idx = self.mem_cnt % self.mem_size

        self.state_memory[mem_idx] = state
        self.action_memory[mem_idx] = action
        self.reward_memory[mem_idx] = reward
        self.next_state_memory[mem_idx] = state_
        self.terminal_memory[mem_idx] = done

        self.mem_cnt += 1

    def sample_buffer(self):
        mem_len = min(self.mem_size, self.mem_cnt)
        batch = np.random.choice(mem_len, self.batch_size, replace=False)

        states = self.state_memory[batch]
        actions = self.action_memory[batch]
        rewards = self.reward_memory[batch]
        states_ = self.next_state_memory[batch]
        terminals = self.terminal_memory[batch]

        return states, actions, rewards, states_, terminals

    def ready(self):
        return self.mem_cnt >= self.batch_size
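A quick sanity-check sketch for the buffer above, with toy dimensions chosen purely for illustration: store_transition() overwrites the oldest entries once mem_cnt exceeds max_size, and ready() only returns True after at least batch_size transitions have been stored.

import numpy as np
from buffer import ReplayBuffer

buffer = ReplayBuffer(max_size=8, state_dim=3, action_dim=2, batch_size=4)
for i in range(10):  # 10 > max_size, so the two oldest transitions are overwritten
    state = np.full(3, i, dtype=np.float64)
    buffer.store_transition(state, np.zeros(2), float(i), state + 1, done=(i == 9))

print(buffer.ready())                  # True once mem_cnt >= batch_size
states, actions, rewards, states_, terminals = buffer.sample_buffer()
print(states.shape, rewards.shape)     # (4, 3) (4,)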
import torch as T
import torch.nn as nn
import torch.optim as optim
device = T.device("cuda:0" if T.cuda.is_available() else "cpu")
def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight, 1.0)
        nn.init.constant_(m.bias, 0.0)


class ActorNetwork(nn.Module):
    def __init__(self, alpha, state_dim, action_dim, fc1_dim, fc2_dim):
        super(ActorNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, fc1_dim)
        self.ln1 = nn.LayerNorm(fc1_dim)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.ln2 = nn.LayerNorm(fc2_dim)
        self.action = nn.Linear(fc2_dim, action_dim)

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        self.apply(weight_init)
        self.to(device)

    def forward(self, state):
        x = T.relu(self.ln1(self.fc1(state)))
        x = T.relu(self.ln2(self.fc2(x)))
        action = T.tanh(self.action(x))

        return action

    def save_checkpoint(self, checkpoint_file):
        T.save(self.state_dict(), checkpoint_file)

    def load_checkpoint(self, checkpoint_file):
        self.load_state_dict(T.load(checkpoint_file))


class CriticNetwork(nn.Module):
    def __init__(self, beta, state_dim, action_dim, fc1_dim, fc2_dim):
        super(CriticNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, fc1_dim)
        self.ln1 = nn.LayerNorm(fc1_dim)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.ln2 = nn.LayerNorm(fc2_dim)
        self.fc3 = nn.Linear(action_dim, fc2_dim)
        self.q = nn.Linear(fc2_dim, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=beta, weight_decay=0.001)
        self.apply(weight_init)
        self.to(device)

    def forward(self, state, action):
        x_s = T.relu(self.ln1(self.fc1(state)))
        x_s = self.ln2(self.fc2(x_s))
        x_a = self.fc3(action)
        x = T.relu(x_s + x_a)
        q = self.q(x)

        return q

    def save_checkpoint(self, checkpoint_file):
        T.save(self.state_dict(), checkpoint_file)

    def load_checkpoint(self, checkpoint_file):
        self.load_state_dict(T.load(checkpoint_file))
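A shape-check sketch for the two networks above, using the hidden sizes from test.py (400/300) and the 8-dimensional state / 2-dimensional action of LunarLanderContinuous-v2; the batch size of 16 is arbitrary:

import torch as T
from networks import ActorNetwork, CriticNetwork

actor = ActorNetwork(alpha=0.0003, state_dim=8, action_dim=2, fc1_dim=400, fc2_dim=300)
critic = CriticNetwork(beta=0.0003, state_dim=8, action_dim=2, fc1_dim=400, fc2_dim=300)

# Both networks move themselves to `device` in __init__, so the input must live there too.
state = T.randn(16, 8).to(next(actor.parameters()).device)
action = actor(state)                  # tanh output, each component in [-1, 1]
q_value = critic(state, action)
print(action.shape, q_value.shape)     # torch.Size([16, 2]) torch.Size([16, 1])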
import gym
import imageio
import argparse
from DDPG import DDPG
from utils import scale_action
parser = argparse.ArgumentParser()
parser.add_argument('--filename', type=str, default='./output_images/LunarLander.gif')
parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/DDPG/')
parser.add_argument('--save_video', type=bool, default=True)
parser.add_argument('--fps', type=int, default=30)
parser.add_argument('--render', type=bool, default=True)
args = parser.parse_args()
def main():
    env = gym.make('LunarLanderContinuous-v2')
    agent = DDPG(alpha=0.0003, beta=0.0003, state_dim=env.observation_space.shape[0],
                 action_dim=env.action_space.shape[0], actor_fc1_dim=400, actor_fc2_dim=300,
                 critic_fc1_dim=400, critic_fc2_dim=300, ckpt_dir=args.checkpoint_dir,
                 batch_size=256)
    agent.load_models(1000)
    video = imageio.get_writer(args.filename, fps=args.fps)

    done = False
    observation = env.reset()
    while not done:
        if args.render:
            env.render()
        # train=True keeps the Gaussian exploration noise; pass train=False for a
        # deterministic rollout of the learned policy.
        action = agent.choose_action(observation, train=True)
        action_ = scale_action(action.copy(), env.action_space.high, env.action_space.low)
        observation_, reward, done, info = env.step(action_)
        observation = observation_
        if args.save_video:
            video.append_data(env.render(mode='rgb_array'))
    video.close()  # flush the gif to disk


if __name__ == '__main__':
    main()
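The rollout in main() keeps exploration noise because it calls choose_action with train=True. For a quieter measure of the learned policy, an evaluation helper along these lines could be added (a sketch, assuming the same imports as test.py; the episode count is arbitrary):

def evaluate(agent, env, episodes=5):
    # Average return of the greedy policy: train=False disables the exploration noise.
    returns = []
    for _ in range(episodes):
        observation, done, total = env.reset(), False, 0.0
        while not done:
            action = agent.choose_action(observation, train=False)
            action_ = scale_action(action.copy(), env.action_space.high, env.action_space.low)
            observation, reward, done, _ = env.step(action_)
            total += reward
        returns.append(total)
    return sum(returns) / len(returns)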