PyTorch 图像分类
依赖库的下载与安装
在终端中执行 pip install -r -requirements.txt 完成项目依赖库的安装
使用方式
数据集的准备
-
STL10 数据集
- 下载:STL-10 Dataset
- 存储位置:将下载后的数据集中 train_X.bin,train_y.bin,test_X.bin,test_y.bin 四个文件存入项目根目录下的 dataset\STL10 子目录内
-
自制数据集
训练模型
训练模型或进行模型预测时,设置 config.py 中的变量 CONTINUE_TRAIN 为 False ,若需要进行断点续训,设置该变量为 True
模型可以选择使用 ResNet50 与 AlexNet 两种网络之一进行训练,在 train.py 中设置训练模型的参数变量 model 来选择想要训练的模型
模型的训练重要超参数存储在 config.py 中,可根据实际需要进行修改
模型训练完成后参数的读取
模型训练完毕后,在项目文件根目录的 model_data 子目录下会生成两个文件,其中 last_model_state_dict.pth 存储了最后一次模型训练的学习率与模型参数信息,用于断点续训;另一个文件为 best_model_state_dict.pth 存储了模型训练过程中验证集的最高准确率所对应的模型参数信息,可以用来预测
测试模型
运行 test.py ,得到测试集预测准确率与混淆矩阵可视化图像
图片预测
将要预测的图片存储在项目根目录 imgs 文件夹下,运行 predict.py 中的 image_classification 函数,将图像名作为参数传递,即可得到预测结果