Commit 471d171f by 靓靓

upload files

parent 18e947f5
++ "b/6-\346\250\241\345\236\213\350\256\255\347\273\203/.gitkeep"
# 课上补充资料:
https://github.com/jindongwang/transferlearning
https://github.com/PacktPublishing/Ensemble-Machine-Learning/tree/master
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:19<00:00, 23.49it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:19<00:00, 23.56it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:19<00:00, 23.55it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:19<00:00, 23.47it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.31it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.63it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 6\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.93it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 7\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.11it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 8\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.12it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 9\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.14it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.91it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 11\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.70it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.87it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 13\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:19<00:00, 23.51it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 14\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.15it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 15\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.94it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 16\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.05it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 17\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.10it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 18\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.40it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 19\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.91it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 20\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.09it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 21\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.98it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 22\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.05it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 23\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.95it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 24\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.93it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 25\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.56it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 26\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 22.65it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 27\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.12it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 28\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.05it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epoch: 29\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:20<00:00, 23.01it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"模型保存成功\n",
"生成模型重载成功\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"import torch\n",
"import torchvision\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets\n",
"from torch.optim import Adam\n",
"import argparse\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
" \n",
"def download():\n",
" # 将图片转化为张量以及归一化处理\n",
" Trans = torchvision.transforms.Compose(\n",
" [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=[0.5], std=[0.5])])\n",
" \n",
" # 下载FashionMNIST对应的训练和测试数据集\n",
" train_data = datasets.FashionMNIST(\n",
" root=\"data\",\n",
" train=True,\n",
" download=True,\n",
" transform=Trans,\n",
" )\n",
" \n",
" test_data = datasets.FashionMNIST(\n",
" root=\"data\",\n",
" train=False,\n",
" download=True,\n",
" transform=Trans,\n",
" )\n",
" \n",
" train_Dataloader = DataLoader(train_data,batch_size=128)\n",
" test_Dataloader = DataLoader(test_data,batch_size=999999)\n",
" \n",
" return train_Dataloader, test_Dataloader, train_data, test_data\n",
" \n",
" \n",
"class Discriminator(nn.Module):\n",
" def __init__(self):\n",
" super(Discriminator, self).__init__()\n",
" self.judge = nn.Sequential(nn.Linear(28*28,512), nn.ReLU(), nn.Linear(512,256), nn.ReLU(), nn.Linear(256,32), nn.ReLU(), nn.Linear(32,1), nn.Sigmoid())\n",
" \n",
" def forward(self,image):\n",
" y = self.judge(image)\n",
" return y\n",
" \n",
" \n",
" \n",
"class Generator(nn.Module):\n",
" def __init__(self):\n",
" super(Generator, self).__init__()\n",
" self.generate = nn.Sequential(nn.Linear(100,256), nn.ReLU(), nn.Linear(256,512), nn.ReLU(), nn.Linear(512,28*28))\n",
" \n",
" def forward(self, x):\n",
" image = self.generate(x)\n",
" return image\n",
" \n",
" \n",
"def train(descriminator, generator, d_optimizer, g_optimizer, train_dataloader, loss_function):\n",
" for real_image,_ in tqdm(train_dataloader):\n",
" real_image = real_image.to('cuda')\n",
" real_image = real_image.reshape(-1,28*28)\n",
" \n",
" # 先看判别器损失\n",
" real_label = descriminator(real_image)\n",
" d_loss_real = loss_function(real_label, torch.ones_like(real_label))\n",
" \n",
" random_tensor = torch.randn(real_image.size(0),100).to('cuda')\n",
" fake_image = generator(random_tensor)\n",
" fake_label = descriminator(fake_image.detach())\n",
" d_loss_fake = loss_function(fake_label, torch.zeros_like(fake_label))\n",
" \n",
" d_loss = d_loss_real + d_loss_fake\n",
" d_optimizer.zero_grad()\n",
" d_loss.backward()\n",
" d_optimizer.step()\n",
" \n",
" \n",
" # 生成器损失\n",
" fake_label = descriminator(fake_image)\n",
" g_loss = loss_function(fake_label, torch.ones_like(fake_label))\n",
" \n",
" g_optimizer.zero_grad()\n",
" g_loss.backward()\n",
" g_optimizer.step()\n",
" \n",
" \n",
" \n",
"if __name__ == \"__main__\":\n",
" train_dataloader, test_dataloader, train_data, test_data = download()\n",
" \n",
" device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
" \n",
" descriminator = Discriminator().to(device)\n",
" generator = Generator().to(device)\n",
" \n",
" loss_function = nn.BCELoss()\n",
" \n",
" d_optimizer = Adam(descriminator.parameters(), lr=0.001)\n",
" g_optimizer = Adam(generator.parameters(), lr=0.001)\n",
" \n",
" epochs = 30\n",
" for epoch in range(epochs):\n",
" print(\"training epoch:\",epoch)\n",
" train(descriminator, generator, d_optimizer, g_optimizer, train_dataloader, loss_function)\n",
" \n",
" \n",
" torch.save(generator.state_dict(),'./generator.pth')\n",
" torch.save(descriminator.state_dict(),'./descriminator.pth')\n",
" print(\"模型保存成功\")\n",
" \n",
" new_generator = Generator()\n",
" new_generator.load_state_dict(torch.load('./generator.pth'))\n",
" print(\"生成模型重载成功\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"生成模型重载成功\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"\n",
"import torchvision\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
" \n",
"new_generator = Generator()\n",
"new_generator.load_state_dict(torch.load('./generator.pth'))\n",
"print(\"生成模型重载成功\")\n",
" \n",
"with torch.no_grad():\n",
" x = torch.randn(16, 100)\n",
" fake = new_generator(x)\n",
" \n",
" fake = fake.reshape(-1, 1, 28, 28)\n",
" img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)\n",
" \n",
" # 将图像张量转换为 NumPy 数组\n",
" img_grid_fake_np = img_grid_fake.cpu().numpy()\n",
" img_grid_fake_np = np.transpose(img_grid_fake_np, (1, 2, 0))\n",
" \n",
" # 使用 matplotlib 显示图像\n",
" plt.figure(figsize=(10, 10))\n",
" plt.imshow(img_grid_fake_np)\n",
" plt.axis('off') # 不显示坐标轴\n",
" plt.show()\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "geognn",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
38.3,5,3.4,37.3,72.1,11.9,16,500,10,80
38.3,5,3.4,37.3,72.1,11.9,16,500,10,80
41.7,5.8,3,22.7,64.6,8.7,26,800,20,40
38.3,5,3.4,37.3,72.1,11.9,16,400,10,70
35.7,5.2,3.5,25.4,60.7,4.6,29.5,650,60,10.83333333
20.3,2.7,3.4,27,50.3,3.9,45.8,600,20,90
28,3.4,4,10.3,38,7.9,46.8,575,28.75,20
23.97,3.59,4.27,5.33,37.93,4.03,58.04,650,13,50
41.7,5.8,3,22.7,64.6,8.7,26,700,20,35
34.375,5.35,36.97,2.32,73.775,5.2475,36.94709524,500,15,120
26.638,4.733,17.513,4.355,53.9876699,7.92815534,43.761,500,10,120
41.185,5.59,40.61,1.94,81.605,7.7225,36.94709524,600,15,120
38.3,5,3.4,37.3,72.1,11.9,16,300,10,60
35.5,5.5,5.1,27.2,63,12.1,24.9,700,10,60
30.2,4.3,4.8,21.4,48.8,3.1,38.1,350,35,10
41.7,5.8,3,22.7,64.6,8.7,26,900,20,45
38.1,5,5.8,15,64.8,5.3,29.9,450,11.25,40
37.78,5.47,38.79,2.13,77.69,6.485,36.94709524,700,15,120
35.7,5.2,3.5,25.4,60.7,4.6,29.5,450,5,90
41.185,5.59,40.61,1.94,81.605,7.7225,36.94709524,400,15,120
33.29,5.06,5.14,16.91,52.08,6.26,35.74,550,12.2,45
20.3,2.7,3.4,27,50.3,3.9,45.8,500,20,85
23.97,3.59,4.27,5.33,37.93,4.03,58.04,550,11,50
24.67,4.65,4.51,18.6,51.44,5.68,46.62,500,15,40
34.375,5.35,36.97,2.32,73.775,5.2475,36.94709524,700,15,120
35.7,5.2,3.5,25.4,60.7,4.6,29.5,450,60,7.5
35.8975,5.1575,30.01,3.6525,53.9876699,7.92815534,25.275,500,10,120
26.52,6.24,4.08,20.24,40.3,8.7,42,700,10,60
35.5,5.5,5.1,27.2,63,12.1,24.9,600,10,60
41.185,5.59,40.61,1.94,81.605,7.7225,36.94709524,500,15,120
41.26,5.425,31.3,3.195,53.9876699,7.92815534,18.82,500,10,120
22.3,3.6,2.8,16.6,41,5.7,53.3,500,10,60
46.615,5.6925,32.59,2.7375,53.9876699,7.92815534,12.365,500,10,120
37.78,5.47,38.79,2.13,77.69,6.485,36.94709524,400,15,120
9.1,0.86,1.5,10.2,18.8,2.8,78.4,700,20,95
52.3,8,6.7,32.3,62.3,6.5,31.2,1040,74.3,24
41.7,5.8,3,22.7,64.6,8.7,26,600,20,30
41.61,5.19,5.61,26.01,64.15,15.57,20.28,600,20,30
41.7,5.8,3,22.7,64.6,8.7,26,500,20,25
21.5,3.3,3.6,16.8,43.3,1.9,54.8,600,20,90
21.5,3.3,3.6,16.8,43.3,1.9,54.8,500,20,85
23.97,3.59,4.27,5.33,37.93,4.03,58.04,450,9,50
20.3,2.7,3.4,27,50.3,3.9,45.8,400,20,80
32.074,4.979,22.859,3.565,53.9876699,7.92815534,36.523,500,10,120
41.7,5.8,3,22.7,64.6,8.7,26,400,20,20
37.78,5.47,38.79,2.13,77.69,6.485,36.94709524,600,15,120
17.6,2.5,3.1,16.1,34.1,5.9,60,500,20,85
34.375,5.35,36.97,2.32,73.775,5.2475,36.94709524,400,15,120
9.1,0.86,1.5,10.2,18.8,2.8,78.4,600,20,90
17.6,2.5,3.1,16.1,34.1,5.9,60,400,20,80
33.292,5.102,25.532,3.17,53.9876699,7.92815534,32.904,500,10,120
26.52,6.24,4.08,20.24,40.3,8.7,42,400,10,60
35.5,5.5,5.1,27.2,63,12.1,24.9,500,10,60
24.67,4.65,4.51,18.6,51.44,5.68,46.62,300,15,40
34.375,5.35,36.97,2.32,73.775,5.2475,36.94709524,600,15,120
30.2,4.3,4.8,21.4,48.8,3.1,38.1,450,45,10
22.3,3.6,2.8,16.6,41,5.7,53.3,600,10,60
47.07,3.91,2.29,44.4,70.4,6.21,11.84,600,25,25
25.5,4.5,4.9,25.9,54.2,8.6,37.2,500,50,10
12.79,1.74,1.2,16.22,29.01,3.49,67.5,600,6.666666667,90
17.6,2.5,3.1,16.1,34.1,5.9,60,300,20,75
34.51,5.225,28.205,2.775,53.9876699,7.92815534,29.285,500,10,120
21.5,3.3,3.6,16.8,43.3,1.9,54.8,300,20,75
38.5,5.08,1.33,39.49,62.85,22.74,14.4,500,5.5,30
38.5,5.08,1.33,39.49,62.85,22.74,14.4,600,5.5,30
29.63,5.3,5.11,24.41,60.89,5.04,34.27,300,10,30
20.3,2.7,3.4,27,50.3,3.9,45.8,300,20,75
29.63,5.3,5.11,24.41,60.89,5.04,34.27,350,10,30
35.7,5.2,3.5,25.4,60.7,4.6,29.5,850,60,14.16666667
29.63,5.3,5.11,24.41,60.89,5.04,34.27,500,10,30
24.9,3.36,0.21,44.31,63.5,9.72,26.78,800,10,30
46.615,5.6925,32.59,2.7375,53.9876699,7.92815534,12.365,300,10,120
41.26,5.425,31.3,3.195,53.9876699,7.92815534,18.82,300,10,120
41.7,5.8,3,22.7,64.6,8.7,26,300,20,15
21.5,3.3,3.6,16.8,43.3,1.9,54.8,400,20,80
35.8975,5.1575,30.01,3.6525,53.9876699,7.92815534,25.275,300,10,120
9.1,0.86,1.5,10.2,18.8,2.8,78.4,500,20,85
12.18,5.82,1.26,23.06,27.12,16.65,56.23,900,35,25.71
29.63,5.3,5.11,24.41,60.89,5.04,34.27,450,10,30
29.63,5.3,5.11,24.41,60.89,5.04,34.27,400,10,30
9.1,0.86,1.5,10.2,18.8,2.8,78.4,400,20,80
24.67,4.65,4.51,18.6,51.44,5.68,46.62,700,15,40
47.07,3.91,2.29,44.4,70.4,6.21,11.84,400,25,25
38.5,5.08,1.33,39.49,62.85,22.74,14.4,400,5.5,30
9.1,0.86,1.5,10.2,18.8,2.8,78.4,300,20,75
17.6,2.5,3.1,16.1,34.1,5.9,60,700,20,95
37.78,5.47,38.79,2.13,77.69,6.485,36.94709524,500,15,120
35.8975,5.1575,30.01,3.6525,53.9876699,7.92815534,25.275,700,10,120
12.18,5.82,1.26,23.06,27.12,16.65,56.23,700,35,20
42.1,6.1,6.4,27.3,63.5,11.6,16.6,450,20,52.5
42.1,6.1,6.4,27.3,63.5,11.6,16.6,525,20,56.25
42.1,6.1,6.4,27.3,63.5,11.6,16.6,600,20,60
23.97,3.59,4.27,5.33,37.93,4.03,58.04,750,15,50
20.3,2.7,3.4,27,50.3,3.9,45.8,700,20,95
42.1,6.1,6.4,27.3,63.5,11.6,16.6,375,20,48.75
42.1,6.1,6.4,27.3,63.5,11.6,16.6,600,30,50
41.26,5.425,31.3,3.195,53.9876699,7.92815534,18.82,700,10,120
42.1,6.1,6.4,27.3,63.5,11.6,16.6,450,30,45
42.1,6.1,6.4,27.3,63.5,11.6,16.6,525,30,47.5
36.52,5.33,5.18,23.15,65.07,7.65,27.69,700,10,60
21.5,3.3,3.6,16.8,43.3,1.9,54.8,700,20,95
12.18,5.82,1.26,23.06,27.12,16.65,56.23,500,35,14.28
30.856,4.856,20.186,3.96,53.9876699,7.92815534,40.142,500,10,120
46.615,5.6925,32.59,2.7375,53.9876699,7.92815534,12.365,700,10,120
30.2,4.3,4.8,21.4,48.8,3.1,38.1,400,40,10
41.185,5.59,40.61,1.94,81.605,7.7225,36.94709524,700,15,120
17.6,2.5,3.1,16.1,34.1,5.9,60,600,20,90
42.1,6.1,6.4,27.3,63.5,11.6,16.6,300,10,60
42.1,6.1,6.4,27.3,63.5,11.6,16.6,375,10,67.5
42.1,6.1,6.4,27.3,63.5,11.6,16.6,450,10,75
42.1,6.1,6.4,27.3,63.5,11.6,16.6,525,10,82.5
30.54,2.2,1.44,8.05,23.66,19.36,56.98,850,20.73170732,41
30.54,2.2,1.44,8.05,23.66,19.36,56.98,650,15.85365854,41
42.1,6.1,6.4,27.3,63.5,11.6,16.6,600,10,90
30.54,2.2,1.44,8.05,23.66,19.36,56.98,450,10.97560976,41
42.1,6.1,6.4,27.3,63.5,11.6,16.6,300,20,45
23.97,3.59,4.27,3.53,37.93,4.03,58.04,450,10.97560976,41
23.97,3.59,4.27,3.53,37.93,4.03,58.04,650,15.85365854,41
0.1
0.1
0.1
0.1
0.105308219
0.107142857
0.11
0.11
0.11
0.110607434
0.111058601
0.113930267
0.12
0.12
0.13
0.13
0.130239521
0.132173192
0.133333333
0.141719217
0.158350515
0.158505155
0.16
0.16
0.161218092
0.163043478
0.169734513
0.17
0.17
0.174592617
0.179504814
0.18
0.182315668
0.188382412
0.197368421
0.2
0.2
0.219606579
0.22
0.220588235
0.224522293
0.24
0.241189427
0.245472837
0.25
0.250607198
0.263157895
0.265535313
0.27
0.270220588
0.271174377
0.28
0.28
0.29
0.296411856
0.3
0.3
0.3
0.3
0.3
0.315088757
0.318587106
0.32
0.321896
0.328737
0.33
0.331395349
0.36
0.380067568
0.39
0.39
0.390319258
0.395809611
0.41
0.42
0.428995253
0.434210526
0.44
0.46
0.47
0.477272727
0.5
0.5
0.513459
0.564705882
0.6
0.632675847
0.676392573
0.686063218
0.7
0.7
0.7
0.7
0.708661417
0.71
0.74
0.753336203
0.77
0.78
0.8
0.803571429
0.831709477
0.850354314
0.855880729
0.9
0.972177806
0.975
1.29
1.3
1.32
1.36
1.454914722
1.481535649
1.5
1.550840203
1.75
2.415956014
2.948008277
Inputs,,,,,,,,,,Outputs
Ultimate analysis (sewage sludge),,Proximate analysis (sewage sludge),,,,,Operating conditions,,,Bio-char
C (wt%),H (wt%),N (wt%),O (wt%),Volatile matter (wt%),Fixed carbon (wt%),Ash (wt%),Temperature (°C),Heating rate (°C/min),Reaction time (min),OC ratio
38.3,5,3.4,37.3,72.1,11.9,16,500,10,80,0.1
41.7,5.8,3,22.7,64.6,8.7,26,800,20,40,0.1
38.3,5,3.4,37.3,72.1,11.9,16,400,10,70,0.1
35.7,5.2,3.5,25.4,60.7,4.6,29.5,650,60,10.83333333,0.105308219
20.3,2.7,3.4,27,50.3,3.9,45.8,600,20,90,0.107142857
28,3.4,4,10.3,38,7.9,46.8,575,28.75,20,0.11
23.97,3.59,4.27,5.33,37.93,4.03,58.04,650,13,50,0.11
41.7,5.8,3,22.7,64.6,8.7,26,700,20,35,0.11
34.375,5.35,26.97,2.32,73.775,5.2475,36.94709524,500,15,120,0.110607434
26.638,4.733,17.513,4.355,53.9876699,7.92815534,43.761,500,10,120,0.111058601
41.185,5.59,30.61,1.94,81.605,7.7225,36.94709524,600,15,120,0.113930267
38.3,5,3.4,37.3,72.1,11.9,16,300,10,60,0.12
35.5,5.5,5.1,27.2,63,12.1,24.9,700,10,60,0.12
30.2,4.3,4.8,21.4,48.8,3.1,38.1,350,35,10,0.13
41.7,5.8,3,22.7,64.6,8.7,26,900,20,45,0.13
38.1,5,5.8,15,64.8,5.3,29.9,450,11.25,40,0.130239521
37.78,5.47,28.79,2.13,77.69,6.485,36.94709524,700,15,120,0.132173192
35.7,5.2,3.5,25.4,60.7,4.6,29.5,450,5,90,0.133333333
41.185,5.59,30.61,1.94,81.605,7.7225,36.94709524,400,15,120,0.141719217
33.29,5.06,5.14,16.91,52.08,6.26,35.74,550,12.2,45,0.158350515
20.3,2.7,3.4,27,50.3,3.9,45.8,500,20,85,0.158505155
23.97,3.59,4.27,5.33,37.93,4.03,58.04,550,11,50,0.16
24.67,4.65,4.51,18.6,51.44,5.68,46.62,500,15,40,0.16
34.375,5.35,26.97,2.32,73.775,5.2475,36.94709524,700,15,120,0.161218092
35.7,5.2,3.5,25.4,60.7,4.6,29.5,450,60,7.5,0.163043478
35.8975,5.1575,20.01,3.6525,53.9876699,7.92815534,25.275,500,10,120,0.169734513
26.52,6.24,4.08,20.24,40.3,8.7,42,700,10,60,0.17
35.5,5.5,5.1,27.2,63,12.1,24.9,600,10,60,0.17
41.185,5.59,20.61,1.94,81.605,7.7225,36.94709524,500,15,120,0.174592617
41.26,5.425,31.3,3.195,53.9876699,7.92815534,18.82,500,10,120,0.179504814
22.3,3.6,2.8,16.6,41,5.7,53.3,500,10,60,0.18
46.615,5.6925,22.59,2.7375,53.9876699,7.92815534,12.365,500,10,120,0.182315668
37.78,5.47,28.79,2.13,77.69,6.485,36.94709524,400,15,120,0.188382412
9.1,0.86,1.5,10.2,18.8,2.8,78.4,700,20,95,0.197368421
52.3,8,6.7,32.3,62.3,6.5,31.2,1040,74.3,24,0.2
41.7,5.8,3,22.7,64.6,8.7,26,600,20,30,0.2
41.61,5.19,5.61,26.01,64.15,15.57,20.28,600,20,30,0.219606579
41.7,5.8,3,22.7,64.6,8.7,26,500,20,25,0.22
21.5,3.3,3.6,16.8,43.3,1.9,54.8,600,20,90,0.220588235
21.5,3.3,3.6,16.8,43.3,1.9,54.8,500,20,85,0.224522293
23.97,3.59,4.27,5.33,37.93,4.03,58.04,450,9,50,0.24
20.3,2.7,3.4,27,50.3,3.9,45.8,400,20,80,0.241189427
32.074,4.979,22.859,3.565,53.9876699,7.92815534,36.523,500,10,120,0.245472837
41.7,5.8,3,22.7,64.6,8.7,26,400,20,20,0.25
37.78,5.47,28.79,2.13,77.69,6.485,36.94709524,600,15,120,0.250607198
17.6,2.5,3.1,16.1,34.1,5.9,60,500,20,85,0.263157895
34.375,5.35,26.97,2.32,73.775,5.2475,36.94709524,400,15,120,0.265535313
9.1,0.86,1.5,10.2,18.8,2.8,78.4,600,20,90,0.27
17.6,2.5,3.1,16.1,34.1,5.9,60,400,20,80,0.270220588
33.292,5.102,25.532,3.17,53.9876699,7.92815534,32.904,500,10,120,0.271174377
26.52,6.24,4.08,20.24,40.3,8.7,42,400,10,60,0.28
35.5,5.5,5.1,27.2,63,12.1,24.9,500,10,60,0.28
24.67,4.65,4.51,18.6,51.44,5.68,46.62,300,15,40,0.29
34.375,5.35,26.97,2.32,73.775,5.2475,36.94709524,600,15,120,0.296411856
30.2,4.3,4.8,21.4,48.8,3.1,38.1,450,45,10,0.3
22.3,3.6,2.8,16.6,41,5.7,53.3,600,10,60,0.3
47.07,3.91,2.29,44.4,70.4,6.21,11.84,600,25,25,0.3
25.5,4.5,4.9,25.9,54.2,8.6,37.2,500,50,10,0.3
12.79,1.74,1.2,16.22,29.01,3.49,67.5,600,6.666666667,90,0.3
17.6,2.5,3.1,16.1,34.1,5.9,60,300,20,75,0.315088757
34.51,5.225,28.205,2.775,53.9876699,7.92815534,29.285,500,10,120,0.318587106
21.5,3.3,3.6,16.8,43.3,1.9,54.8,300,20,75,0.32
38.5,5.08,1.33,39.49,62.85,22.74,14.4,500,5.5,30,0.321896
38.5,5.08,1.33,39.49,62.85,22.74,14.4,600,5.5,30,0.328737
29.63,5.3,5.11,24.41,60.89,5.04,34.27,300,10,30,0.33
20.3,2.7,3.4,27,50.3,3.9,45.8,300,20,75,0.331395349
29.63,5.3,5.11,24.41,60.89,5.04,34.27,350,10,30,0.36
35.7,5.2,3.5,25.4,60.7,4.6,29.5,850,60,14.16666667,0.380067568
29.63,5.3,5.11,24.41,60.89,5.04,34.27,500,10,30,0.39
24.9,3.36,0.21,44.31,63.5,9.72,26.78,800,10,30,0.39
46.615,5.6925,22.59,2.7375,53.9876699,7.92815534,12.365,300,10,120,0.390319258
41.26,5.425,31.3,3.195,53.9876699,7.92815534,18.82,300,10,120,0.395809611
41.7,5.8,3,22.7,64.6,8.7,26,300,20,15,0.41
21.5,3.3,3.6,16.8,43.3,1.9,54.8,400,20,80,0.42
35.8975,5.1575,20.01,3.6525,53.9876699,7.92815534,25.275,300,10,120,0.428995253
9.1,0.86,1.5,10.2,18.8,2.8,78.4,500,20,85,0.434210526
12.18,5.82,1.26,23.06,27.12,16.65,56.23,900,35,25.71,0.44
29.63,5.3,5.11,24.41,60.89,5.04,34.27,450,10,30,0.46
29.63,5.3,5.11,24.41,60.89,5.04,34.27,400,10,30,0.47
9.1,0.86,1.5,10.2,18.8,2.8,78.4,400,20,80,0.477272727
24.67,4.65,4.51,18.6,51.44,5.68,46.62,700,15,40,0.5
47.07,3.91,2.29,44.4,70.4,6.21,11.84,400,25,25,0.5
38.5,5.08,1.33,39.49,62.85,22.74,14.4,400,5.5,30,0.513459
9.1,0.86,1.5,10.2,18.8,2.8,78.4,300,20,75,0.564705882
17.6,2.5,3.1,16.1,34.1,5.9,60,700,20,95,0.6
37.78,5.47,28.79,2.13,77.69,6.485,36.94709524,500,15,120,0.632675847
35.8975,5.1575,30.01,3.6525,53.9876699,7.92815534,25.275,700,10,120,0.676392573
12.18,5.82,1.26,23.06,27.12,16.65,56.23,700,35,20,0.686063218
42.1,6.1,6.4,27.3,63.5,11.6,16.6,450,20,52.5,0.7
42.1,6.1,6.4,27.3,63.5,11.6,16.6,525,20,56.25,0.7
42.1,6.1,6.4,27.3,63.5,11.6,16.6,600,20,60,0.7
23.97,3.59,4.27,5.33,37.93,4.03,58.04,750,15,50,0.7
20.3,2.7,3.4,27,50.3,3.9,45.8,700,20,95,0.708661417
42.1,6.1,6.4,27.3,63.5,11.6,16.6,375,20,48.75,0.71
42.1,6.1,6.4,27.3,63.5,11.6,16.6,600,30,50,0.74
41.26,5.425,31.3,3.195,53.9876699,7.92815534,18.82,700,10,120,0.753336203
42.1,6.1,6.4,27.3,63.5,11.6,16.6,450,30,45,0.77
42.1,6.1,6.4,27.3,63.5,11.6,16.6,525,30,47.5,0.78
36.52,5.33,5.18,23.15,65.07,7.65,27.69,700,10,60,0.8
21.5,3.3,3.6,16.8,43.3,1.9,54.8,700,20,95,0.803571429
12.18,5.82,1.26,23.06,27.12,16.65,56.23,500,35,14.28,0.831709477
30.856,4.856,20.186,3.96,53.9876699,7.92815534,40.142,500,10,120,0.850354314
46.615,5.6925,32.59,2.7375,53.9876699,7.92815534,12.365,700,10,120,0.855880729
30.2,4.3,4.8,21.4,48.8,3.1,38.1,400,40,10,0.9
41.185,5.59,30.61,1.94,81.605,7.7225,36.94709524,700,15,120,0.972177806
17.6,2.5,3.1,16.1,34.1,5.9,60,600,20,90,0.975
42.1,6.1,6.4,27.3,63.5,11.6,16.6,300,10,60,1.29
42.1,6.1,6.4,27.3,63.5,11.6,16.6,375,10,67.5,1.3
42.1,6.1,6.4,27.3,63.5,11.6,16.6,450,10,75,1.32
42.1,6.1,6.4,27.3,63.5,11.6,16.6,525,10,82.5,1.36
30.54,2.2,1.44,8.05,23.66,19.36,56.98,850,20.73170732,41,1.454914722
30.54,2.2,1.44,8.05,23.66,19.36,56.98,650,15.85365854,41,1.481535649
42.1,6.1,6.4,27.3,63.5,11.6,16.6,600,10,90,1.5
30.54,2.2,1.44,8.05,23.66,19.36,56.98,450,10.97560976,41,1.550840203
42.1,6.1,6.4,27.3,63.5,11.6,16.6,300,20,45,1.75
23.97,3.59,4.27,3.53,37.93,4.03,58.04,450,10.97560976,41,2.415956014
23.97,3.59,4.27,3.53,37.93,4.03,58.04,650,15.85365854,41,2.948008277
,,,,,,,,,,
,,,,,,,,,,
52.3,8,8.948663793,44.4,81.605,22.74,78.4,1040,74.3,120,
9.1,0.86,0.21,1.94,18.8,1.9,11.84,300,5,7.5,
,,,,,,,,,,
,,,,,,,,,,
9.1-52.3,0.86-8,0.21-8.95,1.94-44.4,18.8-81.605,1.9-22.74,11.84-78.4,300-800,5-74,,
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# 文本分类实例"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step1 导入相关包"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments\n",
"from datasets import load_dataset"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step2 加载数据集"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = load_dataset(\"csv\", data_files=\"./ChnSentiCorp_htl_all.csv\", split=\"train\")\n",
"dataset = dataset.filter(lambda x: x[\"review\"] is not None)\n",
"dataset"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step3 划分数据集"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"datasets = dataset.train_test_split(test_size=0.1)\n",
"datasets"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step4 数据集预处理"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"hfl/rbt3\")\n",
"\n",
"def process_function(examples):\n",
" tokenized_examples = tokenizer(examples[\"review\"], max_length=128, truncation=True)\n",
" tokenized_examples[\"labels\"] = examples[\"label\"]\n",
" return tokenized_examples\n",
"\n",
"tokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets[\"train\"].column_names)\n",
"tokenized_datasets"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step5 创建模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def model_init():\n",
" model = AutoModelForSequenceClassification.from_pretrained(\"hfl/rbt3\")\n",
" return model"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step6 创建评估函数"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import evaluate\n",
"\n",
"acc_metric = evaluate.load(\"accuracy\")\n",
"f1_metirc = evaluate.load(\"f1\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def eval_metric(eval_predict):\n",
" predictions, labels = eval_predict\n",
" predictions = predictions.argmax(axis=-1)\n",
" acc = acc_metric.compute(predictions=predictions, references=labels)\n",
" f1 = f1_metirc.compute(predictions=predictions, references=labels)\n",
" acc.update(f1)\n",
" return acc"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step7 创建TrainingArguments"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_args = TrainingArguments(output_dir=\"./checkpoints\", # 输出文件夹\n",
" per_device_train_batch_size=64, # 训练时的batch_size\n",
" per_device_eval_batch_size=128, # 验证时的batch_size\n",
" logging_steps=500, # log 打印的频率\n",
" evaluation_strategy=\"epoch\", # 评估策略\n",
" save_strategy=\"epoch\", # 保存策略\n",
" save_total_limit=3, # 最大保存数\n",
" learning_rate=2e-5, # 学习率\n",
" weight_decay=0.01, # weight_decay\n",
" metric_for_best_model=\"f1\", # 设定评估指标\n",
" load_best_model_at_end=True) # 训练完成后加载最优模型"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step8 创建Trainer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import DataCollatorWithPadding\n",
"trainer = Trainer(model_init=model_init, \n",
" args=train_args, \n",
" train_dataset=tokenized_datasets[\"train\"], \n",
" eval_dataset=tokenized_datasets[\"test\"], \n",
" data_collator=DataCollatorWithPadding(tokenizer=tokenizer),\n",
" compute_metrics=eval_metric)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step9 模型训练"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step9 模型训练(自动搜索)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def default_hp_space_optuna(trial):\n",
" return {\n",
" \"learning_rate\": trial.suggest_float(\"learning_rate\", 1e-6, 1e-4),\n",
" \"num_train_epochs\": trial.suggest_int(\"num_train_epochs\", 1, 5),\n",
" \"seed\": trial.suggest_int(\"seed\", 1, 40),\n",
" \"per_device_train_batch_size\": trial.suggest_categorical(\"per_device_train_batch_size\", [4, 8, 16, 32, 64]),\n",
" \"optim\": trial.suggest_categorical(\"optim\", [\"sgd\", \"adamw_hf\"]),\n",
" }\n",
"\n",
"trainer.hyperparameter_search(hp_space=default_hp_space_optuna, compute_objective=lambda x: x[\"eval_f1\"], direction=\"maximize\", n_trials=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "transformers",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-12-05T14:10:02.032031Z",
"start_time": "2024-12-05T14:09:26.615762Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"训练集样本数: 77\n",
"验证集样本数: 20\n",
"测试集样本数: 20\n",
"Fitting 5 folds for each of 270 candidates, totalling 1350 fits\n",
"Best parameters found: {'bootstrap': False, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 100}\n",
"训练集 平均误差率为: 1.2754501965970513e-13 R2为: 1.0\n",
"验证集 平均误差率为: 81.74052460662605 R2为: 0.5743786212193724\n",
"测试集 平均误差率为: 63.459346427509914 R2为: 0.2492545962759135\n",
" 真实值 rf_预测 rf_误差率\n",
"0 0.700000 0.700000 1.903239e-13\n",
"1 1.360000 1.360000 1.632681e-14\n",
"2 0.434211 0.434211 1.406282e-13\n",
"3 0.270000 0.270000 1.439178e-13\n",
"4 0.800000 0.800000 1.942890e-13\n",
".. ... ... ...\n",
"72 1.550840 1.550840 2.577186e-13\n",
"73 0.315089 0.315089 2.290291e-13\n",
"74 0.420000 0.420000 2.379049e-13\n",
"75 1.290000 1.290000 2.237659e-13\n",
"76 0.700000 0.700000 1.903239e-13\n",
"\n",
"[77 rows x 3 columns]\n",
" 真实值 rf_预测 rf_误差率\n",
"0 0.700000 0.384019 45.140109\n",
"1 0.169735 0.381743 124.905791\n",
"2 0.130000 0.143376 10.288865\n",
"3 1.454915 1.453441 0.101268\n",
"4 0.161218 0.183198 13.633632\n",
"5 0.686063 0.832028 21.275638\n",
"6 0.300000 0.849481 183.160380\n",
"7 0.770000 1.154617 49.950272\n",
"8 1.481536 1.514061 2.195402\n",
"9 0.390000 0.382865 1.829540\n",
"10 0.300000 0.189459 36.847012\n",
"11 0.179505 0.339052 88.881954\n",
"12 0.133333 0.316901 137.675474\n",
"13 0.224522 0.376518 67.697221\n",
"14 0.130000 0.894200 587.846154\n",
"15 0.130240 0.345293 165.121638\n",
"16 0.564706 0.464691 17.710985\n",
"17 0.500000 0.374682 25.063608\n",
"18 0.158505 0.231296 45.923637\n",
"19 0.780000 0.705417 9.561911\n",
" 真实值 rf_预测 rf_误差率\n",
"0 0.250607 0.139298 44.415797\n",
"1 0.107143 0.339530 216.895076\n",
"2 0.296412 0.116847 60.579472\n",
"3 0.245473 0.230064 6.277053\n",
"4 0.113930 0.173835 52.579898\n",
"5 0.632676 0.161976 74.398320\n",
"6 0.410000 0.272436 33.552115\n",
"7 0.740000 0.704717 4.767960\n",
"8 0.219607 0.423632 92.904890\n",
"9 0.120000 0.152948 27.456707\n",
"10 0.240000 0.657683 174.034492\n",
"11 0.170000 0.334165 96.567408\n",
"12 0.753336 0.633392 15.921783\n",
"13 0.972178 0.168735 82.643621\n",
"14 0.141719 0.184504 30.189690\n",
"15 0.100000 0.121528 21.527884\n",
"16 0.321896 0.487598 51.476850\n",
"17 1.500000 1.254400 16.373333\n",
"18 0.440000 0.820681 86.518346\n",
"19 0.850354 0.169167 80.106234\n",
"33.53351902961731\n"
]
}
],
"source": [
"# 导入所需的库\n",
"import csv\n",
"import time\n",
"\n",
"import joblib\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.metrics import mean_squared_error, r2_score\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"import warnings\n",
"from sklearn.model_selection import GridSearchCV\n",
"import time\n",
"# 忽略警告信息\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"# 数据加载函数\n",
"def load_data(file_path):\n",
" \"\"\"从CSV文件中加载数据,并转换为浮点数列表\"\"\"\n",
" data = []\n",
" with open(file_path, 'r', encoding='utf-8-sig') as file:\n",
" reader = csv.reader(file)\n",
" for line in reader:\n",
" data.append(line)\n",
" return np.array(data, dtype='float64').tolist()\n",
"\n",
"# 误差率计算函数\n",
"def calculate_error_rate(predictions, labels):\n",
" \"\"\"计算预测值与真实值之间的误差率(百分比)\"\"\"\n",
" return [np.abs((p - l) / l) * 100 if l != 0 else 0 for p, l in zip(predictions, labels)]\n",
"\n",
"# 自定义评分函数:均方根误差(本单元格中未被调用,网格搜索使用的是 scoring='neg_mean_squared_error')\n",
"def custom_rmse(y_true, y_pred):\n",
" \"\"\"计算均方根误差\"\"\"\n",
" mse = mean_squared_error(y_true, y_pred)\n",
" rmse = mse ** 0.5\n",
" return rmse\n",
"\n",
"# 打印预测结果和性能指标\n",
"def print_results(model, cross_data, test_data, train_data, val_labels, test_labels, train_labels):\n",
" \"\"\"打印训练集、验证集和测试集的预测结果及性能指标\"\"\"\n",
" # 预测\n",
" label_predict_cross = model.predict(cross_data)\n",
" label_predict_test = model.predict(test_data)\n",
" label_predict_train = model.predict(train_data)\n",
"\n",
" # 构建DataFrame存储真实值和预测值\n",
" df_cross = pd.DataFrame({'真实值': [i[0] for i in val_labels], 'rf_预测': label_predict_cross})\n",
" df_test = pd.DataFrame({'真实值': [i[0] for i in test_labels], 'rf_预测': label_predict_test})\n",
" df_train = pd.DataFrame({'真实值': [i[0] for i in train_labels], 'rf_预测': label_predict_train})\n",
"\n",
" # 计算误差率\n",
" df_cross['rf_误差率'] = calculate_error_rate(df_cross['rf_预测'], df_cross['真实值'])\n",
" df_test['rf_误差率'] = calculate_error_rate(df_test['rf_预测'], df_test['真实值'])\n",
" df_train['rf_误差率'] = calculate_error_rate(df_train['rf_预测'], df_train['真实值'])\n",
"\n",
" # 计算R²分数\n",
" r2_cross = r2_score(df_cross['真实值'], df_cross['rf_预测'])\n",
" r2_test = r2_score(df_test['真实值'], df_test['rf_预测'])\n",
" r2_train = r2_score(df_train['真实值'], df_train['rf_预测'])\n",
"\n",
" # 打印结果\n",
" print(f\"训练集 平均误差率为: {df_train['rf_误差率'].mean()} R2为: {r2_train}\")\n",
" print(f\"验证集 平均误差率为: {df_cross['rf_误差率'].mean()} R2为: {r2_cross}\")\n",
" print(f\"测试集 平均误差率为: {df_test['rf_误差率'].mean()} R2为: {r2_test}\")\n",
"\n",
" # 打印DataFrame\n",
" print(df_train)\n",
" print(df_cross)\n",
" print(df_test)\n",
"\n",
"# 主函数\n",
"def main():\n",
" # 文件路径\n",
" s=time.time()\n",
" input_file = r'E:\\点头第五期课程\\模型参数优化\\OC\\OC ratio-inputs.csv'\n",
" output_file = r'E:\\点头第五期课程\\模型参数优化\\OC\\OC ratio-outputs.csv'\n",
"\n",
" # 加载数据\n",
" pattern = load_data(input_file)\n",
" label = load_data(output_file)\n",
"\n",
" # 划分数据集\n",
" temp_data, test_data, temp_labels, test_labels = train_test_split(\n",
" pattern, label, test_size=0.17, random_state=42)\n",
" train_data, val_data, train_labels, val_labels = train_test_split(\n",
" temp_data, temp_labels, test_size=0.2, random_state=42)\n",
"\n",
" # 输出数据集大小\n",
" print(\"训练集样本数:\", len(train_data))\n",
" print(\"验证集样本数:\", len(val_data))\n",
" print(\"测试集样本数:\", len(test_data))\n",
"\n",
" # 数据归一化\n",
" scaler = MinMaxScaler()\n",
" normalized_data = scaler.fit_transform(train_data)\n",
" normalized_cross_data = scaler.transform(val_data)\n",
" normalized_test_data = scaler.transform(test_data)\n",
"\n",
" # 定义参数网格\n",
" param_grid = {\n",
" 'n_estimators': [100, 200, 300,400,500],\n",
" 'max_features': ['sqrt', 'log2',None],\n",
" 'min_samples_split': [2, 5, 10],\n",
" 'min_samples_leaf': [1, 2, 4],\n",
" 'bootstrap': [True, False]\n",
" }\n",
"\n",
" # 创建随机森林模型实例\n",
" rf_model = RandomForestRegressor(random_state=42)\n",
"\n",
" # 使用GridSearchCV进行网格搜索\n",
" grid_search = GridSearchCV(estimator=rf_model, param_grid=param_grid, cv=5, scoring='neg_mean_squared_error', n_jobs=-1, verbose=2)\n",
" \n",
" # 训练模型并执行网格搜索\n",
" grid_search.fit(normalized_data, train_labels)\n",
"\n",
" # 输出最佳参数和最佳得分\n",
" print(\"Best parameters found: \", grid_search.best_params_)\n",
" # print(\"Best cross-validation score: {:.2f}\".format(grid_search.best_score_))\n",
"\n",
" # 取出网格搜索结束后已用最佳参数重新拟合(refit)好的最终模型\n",
" best_rf_model = grid_search.best_estimator_\n",
" # 打印预测结果和性能指标\n",
" print_results(best_rf_model, normalized_cross_data, normalized_test_data, normalized_data, val_labels, test_labels, train_labels)\n",
" e=time.time()\n",
" print(e-s)\n",
"if __name__ == \"__main__\":\n",
" main()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-12-05T14:08:22.213978Z",
"start_time": "2024-12-05T14:08:04.149973Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"训练集样本数: 77\n",
"验证集样本数: 20\n",
"测试集样本数: 20\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
"Best parameters found: OrderedDict([('bootstrap', True), ('max_features', 'log2'), ('min_samples_leaf', 5), ('min_samples_split', 8), ('n_estimators', 493)])\n",
"训练集 平均误差率为: 77.53119267362298 R2为: 0.38814354094871295\n",
"验证集 平均误差率为: 81.32371989192606 R2为: 0.47512928208216354\n",
"测试集 平均误差率为: 111.95118725946797 R2为: 0.10900854268492577\n",
" 真实值 rf_预测 rf_误差率\n",
"0 0.700000 0.884169 26.309846\n",
"1 1.360000 0.885138 34.916355\n",
"2 0.434211 0.594416 36.895694\n",
"3 0.270000 0.611973 126.656711\n",
"4 0.800000 0.419645 47.544419\n",
".. ... ... ...\n",
"72 1.550840 0.788595 49.150450\n",
"73 0.315089 0.581139 84.436585\n",
"74 0.420000 0.444108 5.740022\n",
"75 1.290000 0.929957 27.910300\n",
"76 0.700000 0.912329 30.332691\n",
"\n",
"[77 rows x 3 columns]\n",
" 真实值 rf_预测 rf_误差率\n",
"0 0.700000 0.786603 12.371850\n",
"1 0.169735 0.399930 135.621199\n",
"2 0.130000 0.259675 99.750272\n",
"3 1.454915 0.735751 49.429934\n",
"4 0.161218 0.320441 98.762551\n",
"5 0.686063 0.716669 4.461087\n",
"6 0.300000 0.614363 104.787634\n",
"7 0.770000 0.916765 19.060339\n",
"8 1.481536 0.747330 49.557075\n",
"9 0.390000 0.362478 7.056975\n",
"10 0.300000 0.366718 22.239286\n",
"11 0.179505 0.416753 132.168121\n",
"12 0.133333 0.342350 156.762571\n",
"13 0.224522 0.429417 91.257902\n",
"14 0.130000 0.482574 271.211130\n",
"15 0.130240 0.344333 164.384274\n",
"16 0.564706 0.621899 10.128000\n",
"17 0.500000 0.311181 37.763884\n",
"18 0.158505 0.391424 146.947318\n",
"19 0.780000 0.879473 12.752995\n",
" 真实值 rf_预测 rf_误差率\n",
"0 0.250607 0.282672 12.794991\n",
"1 0.107143 0.407840 280.650773\n",
"2 0.296412 0.285887 3.550611\n",
"3 0.245473 0.376188 53.250158\n",
"4 0.113930 0.284760 149.942530\n",
"5 0.632676 0.271753 57.047038\n",
"6 0.410000 0.271471 33.787588\n",
"7 0.740000 0.879425 18.841244\n",
"8 0.219607 0.426237 94.091168\n",
"9 0.120000 0.494396 311.996997\n",
"10 0.240000 0.805267 235.527847\n",
"11 0.170000 0.508306 199.003644\n",
"12 0.753336 0.463927 38.417058\n",
"13 0.972178 0.317351 67.356738\n",
"14 0.141719 0.303130 113.894704\n",
"15 0.100000 0.457592 357.592091\n",
"16 0.321896 0.482197 49.799046\n",
"17 1.500000 0.885779 40.948074\n",
"18 0.440000 0.715868 62.697166\n",
"19 0.850354 0.358558 57.834280\n",
"16.043532609939575\n"
]
}
],
"source": [
"# 导入所需的库\n",
"import csv\n",
"import joblib\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.metrics import mean_squared_error, r2_score\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"import warnings\n",
"import time\n",
"from skopt import BayesSearchCV\n",
"from skopt.space import Real, Integer, Categorical\n",
"# 忽略警告信息\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"# 数据加载函数\n",
"def load_data(file_path):\n",
" \"\"\"从CSV文件中加载数据,并转换为浮点数列表\"\"\"\n",
" data = []\n",
" with open(file_path, 'r', encoding='utf-8-sig') as file:\n",
" reader = csv.reader(file)\n",
" for line in reader:\n",
" data.append(line)\n",
" return np.array(data, dtype='float64').tolist()\n",
"\n",
"# 误差率计算函数\n",
"def calculate_error_rate(predictions, labels):\n",
" \"\"\"计算预测值与真实值之间的误差率(百分比)\"\"\"\n",
" return [np.abs((p - l) / l) * 100 if l != 0 else 0 for p, l in zip(predictions, labels)]\n",
"\n",
"# 自定义评分函数:均方根误差(本单元格中未被调用,贝叶斯搜索使用的是 scoring='neg_mean_squared_error')\n",
"def custom_rmse(y_true, y_pred):\n",
" \"\"\"计算均方根误差\"\"\"\n",
" mse = mean_squared_error(y_true, y_pred)\n",
" rmse = mse ** 0.5\n",
" return rmse\n",
"\n",
"# 打印预测结果和性能指标\n",
"def print_results(model, cross_data, test_data, train_data, val_labels, test_labels, train_labels):\n",
" \"\"\"打印训练集、验证集和测试集的预测结果及性能指标\"\"\"\n",
" # 预测\n",
" label_predict_cross = model.predict(cross_data)\n",
" label_predict_test = model.predict(test_data)\n",
" label_predict_train = model.predict(train_data)\n",
"\n",
" # 构建DataFrame存储真实值和预测值\n",
" df_cross = pd.DataFrame({'真实值': [i[0] for i in val_labels], 'rf_预测': label_predict_cross})\n",
" df_test = pd.DataFrame({'真实值': [i[0] for i in test_labels], 'rf_预测': label_predict_test})\n",
" df_train = pd.DataFrame({'真实值': [i[0] for i in train_labels], 'rf_预测': label_predict_train})\n",
"\n",
" # 计算误差率\n",
" df_cross['rf_误差率'] = calculate_error_rate(df_cross['rf_预测'], df_cross['真实值'])\n",
" df_test['rf_误差率'] = calculate_error_rate(df_test['rf_预测'], df_test['真实值'])\n",
" df_train['rf_误差率'] = calculate_error_rate(df_train['rf_预测'], df_train['真实值'])\n",
"\n",
" # 计算R²分数\n",
" r2_cross = r2_score(df_cross['真实值'], df_cross['rf_预测'])\n",
" r2_test = r2_score(df_test['真实值'], df_test['rf_预测'])\n",
" r2_train = r2_score(df_train['真实值'], df_train['rf_预测'])\n",
"\n",
" # 打印结果\n",
" print(f\"训练集 平均误差率为: {df_train['rf_误差率'].mean()} R2为: {r2_train}\")\n",
" print(f\"验证集 平均误差率为: {df_cross['rf_误差率'].mean()} R2为: {r2_cross}\")\n",
" print(f\"测试集 平均误差率为: {df_test['rf_误差率'].mean()} R2为: {r2_test}\")\n",
"\n",
" # 打印DataFrame\n",
" print(df_train)\n",
" print(df_cross)\n",
" print(df_test)\n",
"\n",
"# 主函数\n",
"def main():\n",
" s=time.time()\n",
" # 文件路径\n",
" input_file = r'E:\\点头第五期课程\\模型参数优化\\OC\\OC ratio-inputs.csv'\n",
" output_file = r'E:\\点头第五期课程\\模型参数优化\\OC\\OC ratio-outputs.csv'\n",
"\n",
" # 加载数据\n",
" pattern = load_data(input_file)\n",
" label = load_data(output_file)\n",
"\n",
" # 划分数据集\n",
" temp_data, test_data, temp_labels, test_labels = train_test_split(\n",
" pattern, label, test_size=0.17, random_state=42)\n",
" train_data, val_data, train_labels, val_labels = train_test_split(\n",
" temp_data, temp_labels, test_size=0.2, random_state=42)\n",
"\n",
" # 输出数据集大小\n",
" print(\"训练集样本数:\", len(train_data))\n",
" print(\"验证集样本数:\", len(val_data))\n",
" print(\"测试集样本数:\", len(test_data))\n",
"\n",
" # 数据归一化\n",
" scaler = MinMaxScaler()\n",
" normalized_data = scaler.fit_transform(train_data)\n",
" normalized_cross_data = scaler.transform(val_data)\n",
" normalized_test_data = scaler.transform(test_data)\n",
"\n",
" # 定义参数空间\n",
" param_space = {\n",
" 'n_estimators': Integer(100, 500), # 树的数量\n",
" 'max_features': Categorical(['sqrt', 'log2',None]),\n",
" 'min_samples_split': Integer(2, 11),\n",
" 'min_samples_leaf': Integer(1, 5),\n",
" 'bootstrap': Categorical([True, False])\n",
" }\n",
"\n",
" # 创建随机森林模型实例\n",
" rf_model = RandomForestRegressor(random_state=42)\n",
"\n",
" # 使用BayesSearchCV进行贝叶斯优化\n",
" bayes_search = BayesSearchCV(\n",
" estimator=rf_model,\n",
" search_spaces=param_space,\n",
" n_iter=20, # 迭代次数\n",
" cv=5, # 五折交叉验证\n",
" scoring='neg_mean_squared_error',\n",
" n_jobs=-1, # 使用所有可用的核心\n",
" verbose=2, # 打印详细信息\n",
" random_state=42 # 确保结果可复现\n",
" )\n",
"\n",
" # 训练模型并执行贝叶斯优化\n",
" bayes_search.fit(normalized_data, train_labels)\n",
"\n",
" # 输出最佳参数和最佳得分\n",
" print(\"Best parameters found: \", bayes_search.best_params_)\n",
" # print(\"Best cross-validation score: {:.2f}\".format(bayes_search.best_score_))\n",
"\n",
" # 取出贝叶斯搜索结束后已用最佳参数重新拟合(refit)好的最终模型\n",
" best_rf_model = bayes_search.best_estimator_\n",
"\n",
" # 打印预测结果和性能指标\n",
" print_results(best_rf_model, normalized_cross_data, normalized_test_data, normalized_data, val_labels, test_labels, train_labels)\n",
" e=time.time()\n",
" print(e-s)\n",
"if __name__ == \"__main__\":\n",
" main()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-12-05T14:10:03.661876Z",
"start_time": "2024-12-05T14:09:40.330750Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"训练集样本数: 77\n",
"验证集样本数: 20\n",
"测试集样本数: 20\n",
"Fitting 5 folds for each of 10 candidates, totalling 50 fits\n",
"Best parameters found: {'bootstrap': True, 'max_features': 'sqrt', 'min_samples_leaf': 3, 'min_samples_split': 9, 'n_estimators': 288}\n",
"训练集 平均误差率为: 70.67592788833923 R2为: 0.5180963229443235\n",
"验证集 平均误差率为: 79.83781270986681 R2为: 0.5767077381613671\n",
"测试集 平均误差率为: 106.22934209295856 R2为: 0.07305682307731742\n",
" 真实值 rf_预测 rf_误差率\n",
"0 0.700000 0.920069 31.438479\n",
"1 1.360000 0.956882 29.641055\n",
"2 0.434211 0.499291 14.988126\n",
"3 0.270000 0.519835 92.531383\n",
"4 0.800000 0.450684 43.664517\n",
".. ... ... ...\n",
"72 1.550840 0.958987 38.163377\n",
"73 0.315089 0.514595 63.317513\n",
"74 0.420000 0.403999 3.809841\n",
"75 1.290000 1.001405 22.371698\n",
"76 0.700000 0.959099 37.014089\n",
"\n",
"[77 rows x 3 columns]\n",
" 真实值 rf_预测 rf_误差率\n",
"0 0.700000 0.895178 27.882609\n",
"1 0.169735 0.389751 129.624018\n",
"2 0.130000 0.273356 110.273906\n",
"3 1.454915 0.866331 40.454848\n",
"4 0.161218 0.328845 103.975032\n",
"5 0.686063 0.753261 9.794730\n",
"6 0.300000 0.618789 106.263098\n",
"7 0.770000 0.978986 27.141046\n",
"8 1.481536 0.895964 39.524670\n",
"9 0.390000 0.362949 6.936210\n",
"10 0.300000 0.344163 14.720926\n",
"11 0.179505 0.411381 129.175378\n",
"12 0.133333 0.338644 153.983317\n",
"13 0.224522 0.389204 73.347707\n",
"14 0.130000 0.488553 275.809710\n",
"15 0.130240 0.336868 158.652387\n",
"16 0.564706 0.525422 6.956602\n",
"17 0.500000 0.319249 36.150150\n",
"18 0.158505 0.356904 125.168832\n",
"19 0.780000 0.943184 20.921080\n",
" 真实值 rf_预测 rf_误差率\n",
"0 0.250607 0.254663 1.618358\n",
"1 0.107143 0.381388 255.962064\n",
"2 0.296412 0.281428 5.054913\n",
"3 0.245473 0.354115 44.258263\n",
"4 0.113930 0.256730 125.339815\n",
"5 0.632676 0.240955 61.914967\n",
"6 0.410000 0.260137 36.551946\n",
"7 0.740000 0.920349 24.371516\n",
"8 0.219607 0.437816 99.363807\n",
"9 0.120000 0.441586 267.987946\n",
"10 0.240000 0.924158 285.065920\n",
"11 0.170000 0.517921 204.659134\n",
"12 0.753336 0.472953 37.218925\n",
"13 0.972178 0.300051 69.136150\n",
"14 0.141719 0.267084 88.460127\n",
"15 0.100000 0.405287 305.287252\n",
"16 0.321896 0.455695 41.565796\n",
"17 1.500000 0.942020 37.198637\n",
"18 0.440000 0.753261 71.195740\n",
"19 0.850354 0.319941 62.375565\n",
"18.86562991142273\n"
]
}
],
"source": [
"# 导入所需的库\n",
"import csv\n",
"import joblib\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.metrics import mean_squared_error, r2_score\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"import warnings\n",
"import time\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from scipy.stats import randint as sp_randint\n",
"# 忽略警告信息\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"# 数据加载函数\n",
"def load_data(file_path):\n",
" \"\"\"从CSV文件中加载数据,并转换为浮点数列表\"\"\"\n",
" data = []\n",
" with open(file_path, 'r', encoding='utf-8-sig') as file:\n",
" reader = csv.reader(file)\n",
" for line in reader:\n",
" data.append(line)\n",
" return np.array(data, dtype='float64').tolist()\n",
"\n",
"# 误差率计算函数\n",
"def calculate_error_rate(predictions, labels):\n",
" \"\"\"计算预测值与真实值之间的误差率(百分比)\"\"\"\n",
" return [np.abs((p - l) / l) * 100 if l != 0 else 0 for p, l in zip(predictions, labels)]\n",
"\n",
"# 自定义评分函数:均方根误差(本单元格中未被调用,随机搜索使用的是 scoring='neg_mean_squared_error')\n",
"def custom_rmse(y_true, y_pred):\n",
" \"\"\"计算均方根误差\"\"\"\n",
" mse = mean_squared_error(y_true, y_pred)\n",
" rmse = mse ** 0.5\n",
" return rmse\n",
"\n",
"# 打印预测结果和性能指标\n",
"def print_results(model, cross_data, test_data, train_data, val_labels, test_labels, train_labels):\n",
" \"\"\"打印训练集、验证集和测试集的预测结果及性能指标\"\"\"\n",
" # 预测\n",
" label_predict_cross = model.predict(cross_data)\n",
" label_predict_test = model.predict(test_data)\n",
" label_predict_train = model.predict(train_data)\n",
"\n",
" # 构建DataFrame存储真实值和预测值\n",
" df_cross = pd.DataFrame({'真实值': [i[0] for i in val_labels], 'rf_预测': label_predict_cross})\n",
" df_test = pd.DataFrame({'真实值': [i[0] for i in test_labels], 'rf_预测': label_predict_test})\n",
" df_train = pd.DataFrame({'真实值': [i[0] for i in train_labels], 'rf_预测': label_predict_train})\n",
"\n",
" # 计算误差率\n",
" df_cross['rf_误差率'] = calculate_error_rate(df_cross['rf_预测'], df_cross['真实值'])\n",
" df_test['rf_误差率'] = calculate_error_rate(df_test['rf_预测'], df_test['真实值'])\n",
" df_train['rf_误差率'] = calculate_error_rate(df_train['rf_预测'], df_train['真实值'])\n",
"\n",
" # 计算R²分数\n",
" r2_cross = r2_score(df_cross['真实值'], df_cross['rf_预测'])\n",
" r2_test = r2_score(df_test['真实值'], df_test['rf_预测'])\n",
" r2_train = r2_score(df_train['真实值'], df_train['rf_预测'])\n",
"\n",
" # 打印结果\n",
" print(f\"训练集 平均误差率为: {df_train['rf_误差率'].mean()} R2为: {r2_train}\")\n",
" print(f\"验证集 平均误差率为: {df_cross['rf_误差率'].mean()} R2为: {r2_cross}\")\n",
" print(f\"测试集 平均误差率为: {df_test['rf_误差率'].mean()} R2为: {r2_test}\")\n",
"\n",
" # 打印DataFrame\n",
" print(df_train)\n",
" print(df_cross)\n",
" print(df_test)\n",
"\n",
"# 主函数\n",
"def main():\n",
" s=time.time()\n",
" # 文件路径\n",
" input_file = r'E:\\点头第五期课程\\模型参数优化\\OC\\OC ratio-inputs.csv'\n",
" output_file = r'E:\\点头第五期课程\\模型参数优化\\OC\\OC ratio-outputs.csv'\n",
"\n",
" # 加载数据\n",
" pattern = load_data(input_file)\n",
" label = load_data(output_file)\n",
"\n",
" # 划分数据集\n",
" temp_data, test_data, temp_labels, test_labels = train_test_split(\n",
" pattern, label, test_size=0.17, random_state=42)\n",
" train_data, val_data, train_labels, val_labels = train_test_split(\n",
" temp_data, temp_labels, test_size=0.2, random_state=42)\n",
"\n",
" # 输出数据集大小\n",
" print(\"训练集样本数:\", len(train_data))\n",
" print(\"验证集样本数:\", len(val_data))\n",
" print(\"测试集样本数:\", len(test_data))\n",
"\n",
" # 数据归一化\n",
" scaler = MinMaxScaler()\n",
" normalized_data = scaler.fit_transform(train_data)\n",
" normalized_cross_data = scaler.transform(val_data)\n",
" normalized_test_data = scaler.transform(test_data)\n",
"\n",
"# 定义参数空间\n",
" param_dist = {\n",
" 'n_estimators': sp_randint(100, 500), # 树的数量\n",
" 'max_features': ['sqrt', 'log2',None],\n",
" 'min_samples_split': sp_randint(2, 11),\n",
" 'min_samples_leaf': sp_randint(1, 5),\n",
" 'bootstrap': [True, False]\n",
" }\n",
"\n",
" # 创建随机森林模型实例\n",
" rf_model = RandomForestRegressor(random_state=42)\n",
"\n",
" # 使用RandomizedSearchCV进行随机搜索\n",
" random_search = RandomizedSearchCV(\n",
" estimator=rf_model,\n",
" param_distributions=param_dist,\n",
" n_iter=10, # 随机采样次数\n",
" cv=5, # 五折交叉验证\n",
" scoring='neg_mean_squared_error',\n",
" n_jobs=-1, # 使用所有可用的核心\n",
" verbose=2, # 打印详细信息\n",
" random_state=42 # 确保结果可复现\n",
" )\n",
"\n",
" # 训练模型并执行随机搜索\n",
" random_search.fit(normalized_data, train_labels)\n",
"\n",
" # 输出最佳参数和最佳得分\n",
" print(\"Best parameters found: \", random_search.best_params_)\n",
" # print(\"Best cross-validation score: {:.2f}\".format(random_search.best_score_))\n",
"\n",
" # 取出随机搜索结束后已用最佳参数重新拟合(refit)好的最终模型\n",
" best_rf_model = random_search.best_estimator_\n",
" # 打印预测结果和性能指标\n",
" print_results(best_rf_model, normalized_cross_data, normalized_test_data, normalized_data, val_labels, test_labels, train_labels)\n",
" e=time.time()\n",
" print(e-s)\n",
"if __name__ == \"__main__\":\n",
" main()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
This source diff could not be displayed because it is too large. You can view the blob instead.
# 图像模型优化补充资料
## 下载链接
请通过以下链接下载补充资料:
- **链接**: [https://pan.baidu.com/s/1EQuRDiDl57kFgGYLNtTh8w?pwd=q2cj](https://pan.baidu.com/s/1EQuRDiDl57kFgGYLNtTh8w?pwd=q2cj)
- **提取码**: q2cj
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment