lxztju / pytorch_classification

利用pytorch实现图像分类的一个完整的代码,训练,预测,TTA,模型融合,模型部署,cnn提取特征,svm或者随机森林等进行分类,模型蒸馏,一个完整的代码
MIT License
1.38k stars 339 forks source link
cnn densenet deployment flask image-classification knn knowledge-distillation label-smoothing pytorch random-forest resnet resnext svm

简介

基于torchision实现的pytorch图像分类功能。

近期更新

习惯之前版本的请看v1版本的代码:V1版本

主要功能:

利用pytorch实现图像分类,基于torchision可以扩展使用densenet,resnext,mobilenet,efficientnet,swin transformer等图像分类网络

如果有用欢迎star

实现功能

运行环境

快速开始

数据集形式

数据集的组织形式,参考sample_files/imgs/listfile.txt

训练 测试

修改run.sh中的参数,直接运行run.sh即可运行

主要修改的参数:

OUTPUT_PATH 模型保存和log文件的路径

TRAIN_LIST 训练数据集的list文件
VAL_LIST  测试集合的list文件
model_name 默认是resnet50
lr 学习率
epochs 训练总的epoch
batch-size  batch的大小
j dataloader的num_workers的大小
num_classes 类别数

libtorch inference

代码存储在cpp_inference文件夹中。

  1. 利用cpp_inference/traced_model/trace_model.py将训练好的模型导出。

  2. 编译所需的opencv和libtorch代码到cpp_inference/third_party_library

  3. 编译

    sh compile.sh
  4. 可执行文件测试

    ./bin/imgCls imgpath