基于torchision实现的pytorch图像分类功能。
2022.11.05更新
2022.10.29更新,进行代码重构,基本的功能基本一致。
习惯之前版本的请看v1版本的代码:V1版本。
主要功能:
利用pytorch实现图像分类,基于torchision可以扩展使用densenet,resnext,mobilenet,efficientnet,swin transformer等图像分类网络
如果有用欢迎star
数据集的组织形式,参考sample_files/imgs/listfile.txt
修改run.sh
中的参数,直接运行run.sh即可运行
主要修改的参数:
OUTPUT_PATH 模型保存和log文件的路径
TRAIN_LIST 训练数据集的list文件
VAL_LIST 测试集合的list文件
model_name 默认是resnet50
lr 学习率
epochs 训练总的epoch
batch-size batch的大小
j dataloader的num_workers的大小
num_classes 类别数
代码存储在cpp_inference
文件夹中。
利用cpp_inference/traced_model/trace_model.py将训练好的模型导出。
编译所需的opencv和libtorch代码到cpp_inference/third_party_library
编译
sh compile.sh
可执行文件测试
./bin/imgCls imgpath