Commit dfc0df4b by 前钰

Upload New File

parent 2a375f53
{
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "9a1ddf95",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ddca5972",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([9, 32, 8])\n",
"torch.Size([2, 3, 32, 8])\n"
]
}
],
"source": [
"#[class1-3, student, scores]\n",
"\n",
"a = torch.rand(3, 32, 8)\n",
"b = torch.rand(6, 32, 8)\n",
"c = torch.rand(3, 32, 8)\n",
"\n",
"print(torch.cat([a, b], dim=0).shape)\n",
"\n",
"print(torch.stack([a,c], dim =0).shape)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d0fd26fc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 32, 8])\n",
"torch.Size([2, 32, 8])\n",
"torch.Size([1, 32, 8])\n",
"3\n",
"torch.Size([3, 32, 8])\n",
"torch.Size([2, 32, 8])\n",
"2\n"
]
}
],
"source": [
"a = torch.rand(5, 32, 8)\n",
"b = torch.split(a, 2, 0)\n",
"print(b[0].shape)\n",
"print(b[1].shape)\n",
"print(b[2].shape)\n",
"print(len(b))\n",
"\n",
"\n",
"c = torch.chunk(a, 2, 0)\n",
"print(c[0].shape)\n",
"print(c[1].shape)\n",
"print(len(c))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5b02c65c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(3.)\n",
"tensor(4.)\n",
"tensor(3.)\n",
"tensor(3.)\n",
"tensor(0.1416)\n"
]
}
],
"source": [
"a = torch.tensor(3.1415926)\n",
"\n",
"# floor\n",
"print(a.floor())\n",
"\n",
"# ceil\n",
"print(a.ceil())\n",
"\n",
"# round\n",
"print(a.round())\n",
"\n",
"# trunc\n",
"print(a.trunc())\n",
"\n",
"# frac\n",
"print(a.frac())"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "b214751c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([1., 2., 3., 4., 5., 6., 7.])\n",
"tensor(4.)\n",
"tensor(7.)\n",
"tensor(1.)\n",
"tensor(28.)\n",
"tensor(5040.)\n",
"tensor(6)\n",
"tensor(0)\n"
]
}
],
"source": [
"a = torch.tensor([1.,2.,3.,4.,5.,6.,7.])\n",
"print(a)\n",
"\n",
"print(a.mean())\n",
"print(a.max())\n",
"print(a.min())\n",
"print(a.sum())\n",
"print(a.prod())\n",
"\n",
"# argmax / argmin\n",
"print(a.argmax())\n",
"print(a.argmin())"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "2f89ff69",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[1., 1., 1.],\n",
" [1., 1., 1.],\n",
" [1., 1., 1.]])\n",
"tensor([[1., 0., 0.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.]])\n",
"tensor([[ True, False, False],\n",
" [False, True, False],\n",
" [False, False, True]])\n",
"False\n"
]
}
],
"source": [
"a = torch.ones(3,3)\n",
"b = torch.eye(3,3)\n",
"print(a)\n",
"print(b)\n",
"\n",
"print(torch.eq(a,b))\n",
"print(torch.equal(a, b))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "14e15af9",
"metadata": {},
"outputs": [],
"source": [
"# 基于pytorch 实现 手写数字的识别问题 mnist"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "d40ebab2",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"import torch.optim as optim"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "d4a79df5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.22868333333333332 0.3601\n",
"0.47163333333333335 0.5569\n",
"0.5959666666666666 0.6348\n",
"0.6523833333333333 0.6823\n",
"0.68475 0.7109\n",
"done\n"
]
}
],
"source": [
"if torch.cuda.is_available():\n",
" device = 'cuda'\n",
"else: \n",
" device = 'cpu'\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" # 28*28 = 784\n",
" self.fc1 = nn.Linear(784, 100)\n",
" self.fc2 = nn.Linear(100, 10)\n",
" # hook\n",
" def forward(self, x):\n",
" x = torch.flatten(x, start_dim = 1)\n",
" x = torch.relu(self.fc1(x))\n",
" x = self.fc2(x)\n",
" \n",
" return x\n",
" \n",
" \n",
"max_epochs = 5\n",
"batch_size = 16\n",
"\n",
"# data\n",
"transform = transforms.Compose([transforms.ToTensor()])\n",
"# 55000\n",
"trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
"train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
"testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
"test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)\n",
"\n",
"# net init\n",
"\n",
"net = Net()\n",
"net.to(device)\n",
"\n",
"# nn. MSE\n",
"loss = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr = 0.0001)\n",
"\n",
"def train():\n",
" acc_num=0\n",
" for epoch in range(max_epochs):\n",
" for i,(data, label) in enumerate(train_loader):\n",
" data = data.to(device)\n",
" label = label.to(device)\n",
" optimizer.zero_grad()\n",
" output = net(data)\n",
" Loss = loss(output, label)\n",
" Loss.backward()\n",
" optimizer.step()\n",
" \n",
" pred_class = torch.max(output, dim=1)[1]\n",
" acc_num += torch.eq(pred_class, label.to(device)).sum().item()\n",
" train_acc = acc_num / len(trainset)\n",
" \n",
" net.eval()\n",
" acc_num = 0.0\n",
" best_acc =0\n",
" with torch.no_grad():\n",
" for val_data in test_loader:\n",
" val_image, val_label = val_data\n",
" output = net(val_image.to(device))\n",
" predict_y = torch.max(output, dim=1)[1]\n",
" acc_num += torch.eq(predict_y, val_label.to(device)).sum().item()\n",
" val_acc = acc_num / len(testset)\n",
" print(train_acc, val_acc)\n",
" if val_acc > best_acc:\n",
" torch.save(net.state_dict(), './minst.pth')\n",
" best_acc = val_acc\n",
" \n",
" acc_num = 0\n",
" train_acc = 0\n",
" test_acc = 0\n",
" print('done')\n",
"\n",
"train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c675ece6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:.conda-pytorch_dl] *",
"language": "python",
"name": "conda-env-.conda-pytorch_dl-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment