DataXujing / YOLOv9

:fire: YOLOv9 paper解析,训练自己的数据集,TensorRT端到端部署, NCNN安卓手机部署
GNU General Public License v3.0
50 stars 10 forks source link
ncnn tensorrt yolov9

Official YOLOv9 训练自己的数据集并基于NVIDIA TensorRT和及安卓手机端部署

1.YOLOv9算法解读

YOLOv9的改进

PGI(可编程梯度信息)

PGI主要包括三个组成部分,即:

GELAN模块

YOLOv9提出了新网络架构——GELAN。GELAN通过结合两种神经网络架构,即结合用梯度路径规划(CSPNet)和(ELAN)设计了一种广义的高效层聚合网络(GELAN);GELAN综合考虑了轻量级、推理速度和准确度。

GELAN整体架构如上图所示。YOLOv9将ELAN的能力进行了泛化,原始ELAN仅使用卷积层的堆叠,而GELAN可以使用任何计算块作为基础Module。

损失函数与样本匹配

通过上图代码可以看到:

模型结构

2.构建自己的训练数据集训练YOLOv9

假设我们有NVIDIA的计算卡,同时配置好了YOLOv9运行需要的环境!

YOLOv9遵循YOLOv5-YOLOv8的训练数据构建方式,可以参考:https://github.com/DataXujing/YOLO-v5, 这里以肺炎X-ray数据集作为训练YOLOv9-c模型的数据集。

path: ./datasets/xray  # dataset root dir
train: images/train/   # train images (relative to 'path') 118287 images
val: images/val/  # val images (relative to 'path') 5000 images
test: images/test/   # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794

# Classes
names:
  0: pneumonia
# YOLOv9

# parameters
nc: 1  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

......
wget https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-c.pt
python3 train_dual.py  --weights=./pretrain/yolov9-c.pt --cfg=./models/detect/yolov9-c.yaml --data=./data/xray.yaml --epoch=100 --batch-size=16 --imgsz=640 --hyp=data/hyps/hyp.scratch-high.yaml

3.YOLOv9推理Demo

python inference.py
Pytorch Pytorch Pytorch

4.YOLOv9 端到端TensorRT加速推理C++实现

python export.py --data=./data/xray.yaml --weights=./runs/train/exp/weights/last.pt --opset=13 --include=onnx --simplify
python onnx_add_nms_op.py

在原onnx模型中插入EfficientNMS Plugin节点:

trtexec --onnx=last_nms.onnx --saveEngine=yolov9-c.plan --workspace=3000 --verbose

恭喜你,TensorRT序列化Engine成功!

我们的代码存放在tensorrt文件夹下(在TensorRT 8.2和TensorRT8.6测试),相同图片在TensorRT C++的推理结果基本一致

TensorRT TensorRT TensorRT

5.YOLOv9安卓手机部署

NCNN-FP32 NCNN-FP32 NCNN-FP32
NCNN-FP16 NCNN-FP16 NCNN-FP16

我们也实现了YOLOv9-c的ncnn下的int8量化,但是目前还存在问题:https://github.com/Tencent/ncnn/issues/5362

小米手机下的部署:

参考

  1. YOLOv9开源 | 架构图&模块改进&正负样本匹配&损失函数解读,5分钟即可理解YOLOv9

  2. YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information

  3. https://github.com/WongKinYiu/yolov9