Commit 18e947f5 by 靓靓

upload files

parents
++ "a/1-AI\346\226\260\346\211\213\346\235\221/.gitkeep"
++ "a/1-AI\346\226\260\346\211\213\346\235\221/1.1-\347\216\257\345\242\203\351\205\215\347\275\256/1.1.1-\347\216\257\345\242\203&\345\214\205\347\232\204\344\273\213\347\273\215\345\222\214\345\256\211\350\243\205/.gitkeep"
++ "a/1-AI\346\226\260\346\211\213\346\235\221/1.1-\347\216\257\345\242\203\351\205\215\347\275\256/1.1.2-Windows\346\267\261\345\272\246\345\255\246\344\271\240\347\216\257\345\242\203\346\220\255\345\273\272/.gitkeep"
++ "a/1-AI\346\226\260\346\211\213\346\235\221/1.1-\347\216\257\345\242\203\351\205\215\347\275\256/1.1.2-Windows\346\267\261\345\272\246\345\255\246\344\271\240\347\216\257\345\242\203\346\220\255\345\273\272/windows\346\267\261\345\272\246\345\255\246\344\271\240\347\216\257\345\242\203\346\220\255\345\273\272.md"
++ "a/1-AI\346\226\260\346\211\213\346\235\221/1.1-\347\216\257\345\242\203\351\205\215\347\275\256/1.1.3-Linux\346\267\261\345\272\246\345\255\246\344\271\240\347\216\257\345\242\203\346\220\255\345\273\272/.gitkeep"
++ "a/1-AI\346\226\260\346\211\213\346\235\221/1.1-\347\216\257\345\242\203\351\205\215\347\275\256/1.1.4-\345\214\205\347\232\204\346\234\254\345\234\260\345\256\211\350\243\205/.gitkeep"
++ "a/1-AI\346\226\260\346\211\213\346\235\221/1.1-\347\216\257\345\242\203\351\205\215\347\275\256/1.1.5-\350\277\234\347\250\213\346\234\215\345\212\241\345\231\250\351\223\276\346\216\245/.gitkeep"
++ "a/1-AI\346\226\260\346\211\213\346\235\221/1.1-\347\216\257\345\242\203\351\205\215\347\275\256/1.1.6-Github\351\241\271\347\233\256\345\244\215\347\216\260/.gitkeep"
++ "a/2-AI\346\246\202\350\277\260/.gitkeep"
++ "a/2-AI\346\246\202\350\277\260/2.1-\347\245\236\347\273\217\347\275\221\347\273\234/2.1.1-\344\272\272\345\267\245\346\231\272\350\203\275\346\226\271\346\263\225\346\246\202\350\277\260/.gitkeep"
++ "a/2-AI\346\246\202\350\277\260/2.1-\347\245\236\347\273\217\347\275\221\347\273\234/2.1.2-\350\256\244\350\257\206\347\245\236\347\273\217\347\275\221\347\273\234/.gitkeep"
++ "a/2-AI\346\246\202\350\277\260/2.1-\347\245\236\347\273\217\347\275\221\347\273\234/2.1.3-\350\256\244\350\257\206\345\275\261\345\203\217\346\225\260\346\215\256/.gitkeep"
++ "a/3-AI\345\267\245\345\205\267\345\214\205/3.1-\347\274\226\347\250\213\350\257\255\350\250\200/3.1.1-Python\345\237\272\347\241\200\347\237\245\350\257\206/.gitkeep"
{
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## python基础入门"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"hello woorld\n"
]
}
],
"source": [
"print(\"hello woorld\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 2 3 4 5\n"
]
}
],
"source": [
"print(1,2,3,4,5)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 2\n"
]
}
],
"source": [
"acc=1\n",
"epoch=2\n",
"print(acc,epoch) # acc为深度学习模型训练的精度指标,epoch为轮数"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 2 3 4 5\n",
"1->2->3->4->5\n"
]
}
],
"source": [
"print(1,2,3,4,5)\n",
"print(1,2,3,4,5,sep=\"->\") # sep指定输出的分隔"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 2 3 4 5\t1 2 3 4 5\n"
]
}
],
"source": [
"print(1,2,3,4,5,end=\"\\t\") # \\t为制表符,默认一个tab\n",
"print(1,2,3,4,5)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 2 3 4 5\n",
"1 2 3 4 5\n"
]
}
],
"source": [
"print(1,2,3,4,5) # \\n为换行\n",
"print(1,2,3,4,5)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"f = open(\"test.txt\",\"a\") # 若没有此txt文件则先创建,然后将内容追加在末尾\n",
"print(\"\\n\",file=f)\n",
"print(\"中文报错,print后面括号为中文()\",file=f)\n",
"f.close()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.30000000000000004\n"
]
}
],
"source": [
"print(0.1+0.2)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1690473024512\n",
"1690473024512\n"
]
}
],
"source": [
"list1 = [1,2,3,5]\n",
"print(id(list1))\n",
"list1[1] = 5\n",
"print(id(list1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 格式化输出"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"我叫qianyu,今年18岁啦\n"
]
}
],
"source": [
"print(\"我叫%s,今年%d岁啦\" % (\"qianyu\",18)) # %s为字符串 %d为整型"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"学校的名称是点头教育,学习的网站是wwww.diantouedu.cn\n"
]
}
],
"source": [
"print(\"学校的名称是{},学习的网站是{}\".format(\"点头教育\",\"wwww.diantouedu.cn\"))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"我叫做qianyu,今年19岁啦\n",
"我叫做qianyu,今年19岁啦\n",
"我叫做qianyu,今年19岁啦\n",
"我叫做qianyu,今年19岁啦\n"
]
}
],
"source": [
"# f表达式格式化输出\n",
"name = \"qianyu\"\n",
"age = 19\n",
"print(f\"我叫做{name},今年{age}岁啦\")\n",
"print(f\"我叫做{name},今年{age}岁啦\")\n",
"print(f\"我叫做{name},今年{age}岁啦\")\n",
"print(f\"我叫做{name},今年{age}岁啦\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"我叫做qianyu,今年19岁啦\n",
"我叫做qianyu,今年19岁啦\n",
"我叫做qianyu,今年19岁啦\n",
"我叫做qianyu,今年19岁啦\n"
]
}
],
"source": [
"print(\"我叫做qianyu,今年19岁啦\")\n",
"print(\"我叫做qianyu,今年19岁啦\")\n",
"print(\"我叫做qianyu,今年19岁啦\")\n",
"print(\"我叫做qianyu,今年19岁啦\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"训练轮数为1\n"
]
}
],
"source": [
"epoch = 1\n",
"print(f\"训练轮数为{epoch}\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['False', 'None', 'True', '__peg_parser__', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue', 'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in', 'is', 'lambda', 'nonlocal', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while', 'with', 'yield']\n"
]
}
],
"source": [
"import keyword\n",
"print(keyword.kwlist)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 变量"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2437147683120\n",
"2437147683152\n"
]
}
],
"source": [
"a = 1\n",
"print(id(a)) \n",
"# print(type(a)) ctrl+/为整体注释\n",
"a = 2\n",
"print(id(a))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 输入函数"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"hello\n",
"<class 'str'>\n"
]
}
],
"source": [
"age = input() # 无论输入的是什么 都会被转化为字符串类型\n",
"print(age)\n",
"print(type(age))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "test",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Python基础编程
# Python基础编程
这个仓库包含了关于第一和第二节课的Python基础编程代码。
## 第一节课
- 代码文件: `python基础入门.ipynb`
- 文本文件: `test.txt` (包含一些常见的报错,是第一节课中用到的文本文件)
## 第二节课
- 代码文件: `六大数据类型2.ipynb`(六大数据类型.ipynb为往期代码,可做参考)
\ No newline at end of file
# Python基础报错例子
# Python基础报错例子
SyntaxError(语法错误):
print "Hello, World!"
# 缺少括号,应该是 print("Hello, World!")
IndentationError(缩进错误):
def my_function():
print("Hello, World!")
# 函数体缩进不正确
NameError(名称错误):
print(x)
# x未定义
TypeError(类型错误):
x = "5"
y = 2
z = x + y
# 字符串和整数不能直接相加
IndexError(索引错误):
my_list = [1, 2, 3]
print(my_list[3])
# 索引超出了列表范围
KeyError(键错误):
my_dict = {"name": "Alice", "age": 30}
print(my_dict["gender"])
# 字典中没有"gender"这个键
ValueError(数值错误):
int("abc")
# 无法将非数字字符串转换为整数
FileNotFoundError(文件未找到错误):
with open("nonexistent_file.txt", "r") as f:
content = f.read()
# 文件不存在
ImportError(导入错误):
import nonexistent_module
# 未安装或导入的模块不存在
ZeroDivisionError(除零错误):
result = 10 / 0
# 除数为零
AttributeError(属性错误):
x = 5
print(x.append(1))
# 整数对象没有append()方法
KeyboardInterrupt(键盘中断):
while True:
pass
# 执行这段代码后,按下Ctrl+C将会中断程序执行hello world
中文报错,print后面括号为中文()
++ "a/3-AI\345\267\245\345\205\267\345\214\205/3.1-\347\274\226\347\250\213\350\257\255\350\250\200/3.1.2-Python\345\270\270\347\224\250\345\207\275\346\225\260/.gitkeep"
{
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"2\n"
]
}
],
"source": [
"a = 1\n",
"b = 2\n",
"print(a)\n",
"print(b)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"b的值为1\n"
]
}
],
"source": [
"a = -1\n",
"if a>=0:\n",
" print(\"a的值是非负\")\n",
"\n",
"\n",
"b = 1\n",
"print(f\"b的值为{b}\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a是非负数,a的值为0\n"
]
}
],
"source": [
"a = 0\n",
"if a>=0:\n",
" print(f\"a是非负数,a的值为{a}\")\n",
"else:\n",
" print(f\"a是负数,a的值为{a}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"score = 100\n",
"if score>=90 and score<=100:\n",
" print(\"该学生的成绩很不错,为A\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "test",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
### Python进阶知识
### Python进阶知识
这里是Python部分进阶的知识,主要包括以下内容:
#### 1. 三大程序结构
- **顺序:** 指程序从上到下按照顺序执行的结构。
- **分支:** 涉及条件判断,根据条件不同执行不同的代码块。
- **循环:** 允许代码块多次执行,直到满足退出条件为止。
#### 2. 函数的定义和使用
函数是一段可重复使用的代码块,可以接受参数并返回值。
#### 3. 经典编程案例
- **递归求n阶乘:** 通过递归方式计算给定数的阶乘。
- **求解水仙花数:** 寻找指定范围内的水仙花数,即每位数字的立方和等于该数本身的数。
++ "a/3-AI\345\267\245\345\205\267\345\214\205/3.1-\347\274\226\347\250\213\350\257\255\350\250\200/3.1.3-Python\351\235\242\345\220\221\345\257\271\350\261\241\347\274\226\347\250\213/.gitkeep"
## 本章节讲解的是 Python 面向对象编程的相关内容,包括类和对象的定义和使用、封装、继承、多态。在学习本章之前需要先看环境配置、Python基础、进阶课程
## 本章节讲解的是 Python 面向对象编程的相关内容,包括类和对象的定义和使用、封装、继承、多态。在学习本章之前需要先看环境配置、Python基础、进阶课程
[
[
{
"id": "001",
"name": "点头教育",
"url": "www.diantouedu.cn",
"age": 10
},
{
"id": "002",
"name": "Google",
"url": "www.google.com",
"age": 100
},
{
"id": "003",
"name": "淘宝",
"url": "www.taobao.com",
"age": 50
}
]
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
++ "a/3-AI\345\267\245\345\205\267\345\214\205/3.4-\346\267\261\345\272\246\345\255\246\344\271\240\346\241\206\346\236\266/3.4.1-PyTorch\345\237\272\346\234\254\346\225\260\346\215\256\347\261\273\345\236\213/.gitkeep"
++ "a/3-AI\345\267\245\345\205\267\345\214\205/3.4-\346\267\261\345\272\246\345\255\246\344\271\240\346\241\206\346\236\266/3.4.2-PyTorch\346\241\206\346\236\266\347\232\204\344\275\277\347\224\250/.gitkeep"
{
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# FashionMNIST 数据集介绍\n",
"\n",
"`FashionMNIST` 数据集是由 Zalando 提供的一个包含时尚商品图像的数据集,旨在为机器学习和计算机视觉任务提供一个简单的基准数据集。与经典的 MNIST 数据集类似,FashionMNIST 包含 10 个类别的图像,但这些图像是时尚产品(如衣物、鞋子等),而不是手写数字。每张图像的大小为 28x28 像素,灰度图像,共有 70,000 张样本(其中 60,000 张用于训练,10,000 张用于测试)。\n",
"\n",
"## 类别标签及其中文翻译\n",
"\n",
"以下是 `FashionMNIST` 数据集中的 10 个类别标签及其对应的中文翻译:\n",
"\n",
"1. **T-shirt/top** - T恤/上衣\n",
"2. **Trouser** - 长裤\n",
"3. **Pullover** - 套头衫\n",
"4. **Dress** - 连衣裙\n",
"5. **Coat** - 外套\n",
"6. **Sandal** - 凉鞋\n",
"7. **Shirt** - 衬衫\n",
"8. **Sneaker** - 运动鞋\n",
"9. **Bag** - 包\n",
"10. **Ankle boot** - 短靴\n",
"\n",
"这些标签代表了数据集中包含的各种时尚商品类别,每个类别的图像都是通过计算机视觉方法进行分类的目标。\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 导入必要的库\n",
"import os # 用于操作文件和目录\n",
"import torch # 用于PyTorch相关操作,尽管在本代码中未使用(可能未来有用)\n",
"from torchvision import datasets, transforms # 用于数据集和数据变换\n",
"from PIL import Image # 用于图像操作,特别是将Tensor转换为图像并保存\n",
"from torchvision.datasets import FashionMNIST # 用于加载FashionMNIST数据集\n",
"\n",
"# 定义数据集保存路径\n",
"data_dir = './FashionMNIST_images' # 定义数据集的根目录\n",
"train_dir = os.path.join(data_dir, 'train') # 训练集保存路径\n",
"test_dir = os.path.join(data_dir, 'test') # 测试集保存路径\n",
"\n",
"# 定义分类标签,FashionMNIST数据集有10个类别\n",
"class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n",
"\n",
"# 创建文件夹结构\n",
"def create_folders(base_dir):\n",
" \"\"\"\n",
" 创建用于保存数据集的文件夹结构。将数据集分为训练集和测试集,并根据类别在每个集下创建子文件夹。\n",
" \"\"\"\n",
" for split in ['train', 'test']: # 对训练集和测试集分别操作\n",
" split_dir = os.path.join(base_dir, split) # 拼接路径\n",
" os.makedirs(split_dir, exist_ok=True) # 如果文件夹不存在则创建,存在则不做任何操作\n",
" for class_name in class_names: # 为每个类别创建一个文件夹\n",
" class_dir = os.path.join(split_dir, class_name) # 拼接类别文件夹路径\n",
" os.makedirs(class_dir, exist_ok=True) # 创建类别文件夹\n",
"\n",
"# 保存图像到文件夹\n",
"def save_images(dataset, split='train'):\n",
" \"\"\"\n",
" 将数据集中的图像按类别保存到相应的文件夹。\n",
" \"\"\"\n",
" for i, (image, label) in enumerate(dataset): # 遍历数据集中的每个图像和标签\n",
" class_name = class_names[label] # 获取类别名称\n",
" folder_path = os.path.join(data_dir, split, class_name) # 拼接类别文件夹路径\n",
" \n",
" # 保存图片\n",
" image = transforms.ToPILImage()(image) # 将Tensor格式的图像转换为PIL图像,以便保存\n",
" image.save(os.path.join(folder_path, f'{split}_{i}.png')) # 保存为PNG格式,文件名为 \"split_i.png\"\n",
"\n",
"# 下载FashionMNIST数据集并应用转换\n",
"transform = transforms.Compose([transforms.ToTensor()]) # 将图片转换为Tensor格式\n",
"train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform) # 下载并加载训练集\n",
"test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform) # 下载并加载测试集\n",
"\n",
"# 创建用于保存数据集的文件夹结构\n",
"create_folders(data_dir)\n",
"\n",
"# 将训练和测试数据保存到文件夹中\n",
"save_images(train_dataset, 'train') # 保存训练集图像\n",
"save_images(test_dataset, 'test') # 保存测试集图像\n",
"\n",
"# 打印保存成功的信息\n",
"print(\"Images saved successfully.\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "test",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# 项目名称
# 项目名称
该项目包含FashionMNIST数据集的转换、预处理、模型训练及推理。
## 项目文件说明
- `data_conv.ipynb`: FashionMNIST数据集转换脚本。该脚本用于将从`torchvision`中下载的FashionMNIST数据集转换为图片格式并保存。
- `train.ipynb`: 数据预处理、模型训练及推理的完整流程脚本。该notebook包含了从数据加载到最终推理的所有步骤。
- `train_0.py`: 单独的训练脚本(可选),可独立运行以进行模型训练。
## 数据集
项目中使用的FashionMNIST数据集,可以通过`torchvision`下载。转换后的图片格式数据将用于模型训练和推理。
## 项目使用说明
1. 下载数据集转换脚本并运行`data_conv.ipynb`,将FashionMNIST数据集转换为图片格式。
2. 运行`train.ipynb`,进行数据预处理、模型训练及推理。
3. (可选)运行`train_0.py`,单独进行模型训练。
## 数据集链接
链接: [https://pan.baidu.com/s/1uxkvg3NWyollRVqrfkAXcQ?pwd=8888](https://pan.baidu.com/s/1uxkvg3NWyollRVqrfkAXcQ?pwd=8888)
提取码: 8888
import torch
import torch
import torch.nn as nn
import torch.optim as optim # 导入优化器
from torchvision import datasets, transforms # 导入数据集和数据预处理库
from torch.utils.data import DataLoader # 数据加载库
# 设置随机种子
torch.manual_seed(21)
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize((0.5), (0.5)) # 标准化图像数据 灰度图,只需要一个0.5 -1 - 1
])
# 加载FashionMNIST数据集
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform) # 下载训练集
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform) # 下载测试集
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 对训练集进行打包, 一批次64个图像塞入神经网络训练
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) # 对测试集进行打包
# 定义神经网络模型
class QYNN(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28*28, 128) # 定义第一个全连接层 隐藏层的神经元个数为128
self.fc2 = nn.Linear(128, 10) # 定义第二个全连接层 输出神经元个数 10 因为我们需要做10分类
def forward(self, x): # 前向传播
x = torch.flatten(x, start_dim=1) # 展平数据,方便进行全连接
x = torch.relu(self.fc1(x)) # 非线性
x = self.fc2(x) # 十分类 [0.1,0.2,0.5,0.2,0,0,0,0,0,0]
return x
# 初始化模型
model = QYNN()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵
optimizer = optim.SGD(model.parameters(), lr=0.01) # lr 学习率 用来调整模型收敛速度 0.1
# 训练模型
epochs = 10
for epoch in range(epochs): # 0-9
running_loss = 0.0 # 定义初始loss为0
for inputs, labels in train_loader:
optimizer.zero_grad() # 梯度清零
outputs = model(inputs) # 将图片塞进网络训练获得 输出
loss = criterion(outputs, labels) # 根据输出和标签做对比计算损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
running_loss += loss.item() # loss值累加
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}")
# 测试模型
correct = 0 # 正确的数量
total = 0 # 样本总数
with torch.no_grad(): # 不用进行梯度计算
for inputs, labels in test_loader:
# print(labels.shape)
outputs = model(inputs) # [0.1,0.2,0.5,0.2,0,0,0,0,0,0] 2
# print(outputs.shape)
_, predicted = torch.max(outputs, 1) # _取到的最大值,可以不要, 我们需要的是最大值对应的索引 也就是label(predicted)
total += labels.size(0) # 获取当前批次样本数量
correct += (predicted == labels).sum().item() # 对预测对的值进行累加
print(f"Accuracy on test set: {correct/total:.2%}")
++ "a/3-AI\345\267\245\345\205\267\345\214\205/3.4-\346\267\261\345\272\246\345\255\246\344\271\240\346\241\206\346\236\266/3.4.3-TensorFlow\346\241\206\346\236\266\347\232\204\344\275\277\347\224\250/.gitkeep"
This source diff could not be displayed because it is too large. You can view the blob instead.
# Default ignored files
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
<?xml version="1.0" encoding="UTF-8"?>
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="root@region-45.autodl.pro:28338 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@region-45.autodl.pro:28338 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
</component>
</project>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="11">
<item index="0" class="java.lang.String" itemvalue="tensorboard" />
<item index="1" class="java.lang.String" itemvalue="opencv-python" />
<item index="2" class="java.lang.String" itemvalue="torch" />
<item index="3" class="java.lang.String" itemvalue="numpy" />
<item index="4" class="java.lang.String" itemvalue="torchvision" />
<item index="5" class="java.lang.String" itemvalue="tqdm" />
<item index="6" class="java.lang.String" itemvalue="scipy" />
<item index="7" class="java.lang.String" itemvalue="h5py" />
<item index="8" class="java.lang.String" itemvalue="matplotlib" />
<item index="9" class="java.lang.String" itemvalue="opencv_python" />
<item index="10" class="java.lang.String" itemvalue="Pillow" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/StructuredData.iml" filepath="$PROJECT_DIR$/.idea/StructuredData.iml" />
</modules>
</component>
</project>
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
PassengerId,Survived
PassengerId,Survived
892,0
893,1
894,0
895,0
896,1
897,0
898,1
899,0
900,1
901,0
902,0
903,0
904,1
905,0
906,1
907,1
908,0
909,0
910,1
911,1
912,0
913,0
914,1
915,0
916,1
917,0
918,1
919,0
920,0
921,0
922,0
923,0
924,1
925,1
926,0
927,0
928,1
929,1
930,0
931,0
932,0
933,0
934,0
935,1
936,1
937,0
938,0
939,0
940,1
941,1
942,0
943,0
944,1
945,1
946,0
947,0
948,0
949,0
950,0
951,1
952,0
953,0
954,0
955,1
956,0
957,1
958,1
959,0
960,0
961,1
962,1
963,0
964,1
965,0
966,1
967,0
968,0
969,1
970,0
971,1
972,0
973,0
974,0
975,0
976,0
977,0
978,1
979,1
980,1
981,0
982,1
983,0
984,1
985,0
986,0
987,0
988,1
989,0
990,1
991,0
992,1
993,0
994,0
995,0
996,1
997,0
998,0
999,0
1000,0
1001,0
1002,0
1003,1
1004,1
1005,1
1006,1
1007,0
1008,0
1009,1
1010,0
1011,1
1012,1
1013,0
1014,1
1015,0
1016,0
1017,1
1018,0
1019,1
1020,0
1021,0
1022,0
1023,0
1024,1
1025,0
1026,0
1027,0
1028,0
1029,0
1030,1
1031,0
1032,1
1033,1
1034,0
1035,0
1036,0
1037,0
1038,0
1039,0
1040,0
1041,0
1042,1
1043,0
1044,0
1045,1
1046,0
1047,0
1048,1
1049,1
1050,0
1051,1
1052,1
1053,0
1054,1
1055,0
1056,0
1057,1
1058,0
1059,0
1060,1
1061,1
1062,0
1063,0
1064,0
1065,0
1066,0
1067,1
1068,1
1069,0
1070,1
1071,1
1072,0
1073,0
1074,1
1075,0
1076,1
1077,0
1078,1
1079,0
1080,1
1081,0
1082,0
1083,0
1084,0
1085,0
1086,0
1087,0
1088,0
1089,1
1090,0
1091,1
1092,1
1093,0
1094,0
1095,1
1096,0
1097,0
1098,1
1099,0
1100,1
1101,0
1102,0
1103,0
1104,0
1105,1
1106,1
1107,0
1108,1
1109,0
1110,1
1111,0
1112,1
1113,0
1114,1
1115,0
1116,1
1117,1
1118,0
1119,1
1120,0
1121,0
1122,0
1123,1
1124,0
1125,0
1126,0
1127,0
1128,0
1129,0
1130,1
1131,1
1132,1
1133,1
1134,0
1135,0
1136,0
1137,0
1138,1
1139,0
1140,1
1141,1
1142,1
1143,0
1144,0
1145,0
1146,0
1147,0
1148,0
1149,0
1150,1
1151,0
1152,0
1153,0
1154,1
1155,1
1156,0
1157,0
1158,0
1159,0
1160,1
1161,0
1162,0
1163,0
1164,1
1165,1
1166,0
1167,1
1168,0
1169,0
1170,0
1171,0
1172,1
1173,0
1174,1
1175,1
1176,1
1177,0
1178,0
1179,0
1180,0
1181,0
1182,0
1183,1
1184,0
1185,0
1186,0
1187,0
1188,1
1189,0
1190,0
1191,0
1192,0
1193,0
1194,0
1195,0
1196,1
1197,1
1198,0
1199,0
1200,0
1201,1
1202,0
1203,0
1204,0
1205,1
1206,1
1207,1
1208,0
1209,0
1210,0
1211,0
1212,0
1213,0
1214,0
1215,0
1216,1
1217,0
1218,1
1219,0
1220,0
1221,0
1222,1
1223,0
1224,0
1225,1
1226,0
1227,0
1228,0
1229,0
1230,0
1231,0
1232,0
1233,0
1234,0
1235,1
1236,0
1237,1
1238,0
1239,1
1240,0
1241,1
1242,1
1243,0
1244,0
1245,0
1246,1
1247,0
1248,1
1249,0
1250,0
1251,1
1252,0
1253,1
1254,1
1255,0
1256,1
1257,1
1258,0
1259,1
1260,1
1261,0
1262,0
1263,1
1264,0
1265,0
1266,1
1267,1
1268,1
1269,0
1270,0
1271,0
1272,0
1273,0
1274,1
1275,1
1276,0
1277,1
1278,0
1279,0
1280,0
1281,0
1282,0
1283,1
1284,0
1285,0
1286,0
1287,1
1288,0
1289,1
1290,0
1291,0
1292,1
1293,0
1294,1
1295,0
1296,0
1297,0
1298,0
1299,0
1300,1
1301,1
1302,1
1303,1
1304,1
1305,0
1306,1
1307,0
1308,0
1309,0
This source diff could not be displayed because it is too large. You can view the blob instead.
# Default ignored files
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
<?xml version="1.0" encoding="UTF-8"?>
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.10" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="root@region-45.autodl.pro:28338 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="root@region-45.autodl.pro:28338 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
</component>
</project>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="11">
<item index="0" class="java.lang.String" itemvalue="tensorboard" />
<item index="1" class="java.lang.String" itemvalue="opencv-python" />
<item index="2" class="java.lang.String" itemvalue="torch" />
<item index="3" class="java.lang.String" itemvalue="numpy" />
<item index="4" class="java.lang.String" itemvalue="torchvision" />
<item index="5" class="java.lang.String" itemvalue="tqdm" />
<item index="6" class="java.lang.String" itemvalue="scipy" />
<item index="7" class="java.lang.String" itemvalue="h5py" />
<item index="8" class="java.lang.String" itemvalue="matplotlib" />
<item index="9" class="java.lang.String" itemvalue="opencv_python" />
<item index="10" class="java.lang.String" itemvalue="Pillow" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/ImageData.iml" filepath="$PROJECT_DIR$/.idea/ImageData.iml" />
</modules>
</component>
</project>
\ No newline at end of file
# This is a sample Python script.
# This is a sample Python script.
# Press Shift+F10 to execute it or replace it with your code.
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
def print_hi(name):
# Use a breakpoint in the code line below to debug your script.
print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint.
# Press the green button in the gutter to run the script.
if __name__ == '__main__':
print_hi('PyCharm')
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
This source diff could not be displayed because it is too large. You can view the blob instead.
下载链接<br>
下载链接<br>
链接:https://pan.baidu.com/s/1Kmue_npXNeHLU1z2XuB68Q?pwd=wgxq <br>
提取码:wgxq<br>
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment