levipereira opened this issue 7 months ago
Updated. Thanks a lot.
TensorRT version: 10.0.0
YOLOv9-C
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin (Pytorch) | 0.529 | 0.699 | 0.743 | 0.634 |
INT8 (TensorRT) | 0.527 | 0.695 | 0.746 | 0.627 |
Eval Model | AP Diff | AP50 Diff | Precision Diff | Recall Diff |
---|---|---|---|---|
INT8 (TensorRT) vs Origin (Pytorch) | -0.002 | -0.004 | +0.003 | -0.007 |
GPU | |
---|---|
Device | NVIDIA GeForce RTX 4090 |
Compute Capability | 8.9 |
SMs | 128 |
Device Global Memory | 24207 MiB |
Application Compute Clock Rate | 2.58 GHz |
Application Memory Clock Rate | 10.501 GHz |
Model Name | Batch Size | Latency (99%) | Throughput (qps) | Total Inferences (IPS) |
---|---|---|---|---|
YOLOv9-C (FP16) | 1 | 1.25 ms | 803 | 803 |
YOLOv9-C (FP16) | 4 | 3.37 ms | 300 | 1200 |
YOLOv9-C (FP16) | 8 | 6.6 ms | 153 | 1224 |
YOLOv9-C (FP16) | 12 | 10 ms | 99 | 1188 |
YOLOv9-C (INT8) | 1 | 0.99 ms | 1006 | 1006 |
YOLOv9-C (INT8) | 4 | 2.12 ms | 473 | 1892 |
YOLOv9-C (INT8) | 8 | 3.84 ms | 261 | 2088 |
YOLOv9-C (INT8) | 12 | 5.59 ms | 178 | 2136 |
Model Name | Batch Size | Latency (99%) Diff | Throughput (qps) Diff | Total Inferences Diff |
---|---|---|---|---|
INT8 vs FP16 | 1 | -20.8% | +25.2% | +25.2% |
INT8 vs FP16 | 4 | -37.1% | +57.7% | +57.7% |
INT8 vs FP16 | 8 | -41.1% | +70.6% | +70.6% |
INT8 vs FP16 | 12 | -46.9% | +79.8% | +78.9% |
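For anyone reproducing these comparisons, the percentage differences follow directly from the two tables above; a minimal sketch (values copied from the batch-size-1 rows):

```python
# INT8 vs FP16 at batch size 1, values taken from the tables above.
fp16_latency_ms, int8_latency_ms = 1.25, 0.99
fp16_qps, int8_qps = 803, 1006

latency_diff = (int8_latency_ms - fp16_latency_ms) / fp16_latency_ms * 100
throughput_diff = (int8_qps - fp16_qps) / fp16_qps * 100

print(f"latency {latency_diff:+.1f}%, throughput {throughput_diff:+.1f}%")
# latency -20.8%, throughput +25.3% (the table rounds the latter to +25.2%)
```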
@WongKinYiu Do you happen to have a YOLOv9-C or YOLOv9-E model trained with ReLU or ReLU6 activation functions? I need it for performance testing with quantization. If you have one available and could share it, it would greatly help me.
Sorry for the late reply; yolov9-relu.pt is here. It is not yet re-parameterized.
@WongKinYiu Thank you for providing the weights file. As I suspected, the ReLU activation function delivers much better performance (latency) compared to SiLU. Depending on the scenario, it might be worth sacrificing a bit of accuracy for the sake of latency.
The current results have been quite satisfactory, achieving a minimum latency of 0.84 ms.
My next goal is to test with the ReLU6 function.
Below are the tables of the results:
TensorRT version: 10.0.0
GPU | |
---|---|
Device | NVIDIA GeForce RTX 4090 |
Compute Capability | 8.9 |
SMs | 128 |
Device Global Memory | 24207 MiB |
Application Compute Clock Rate | 2.58 GHz |
Application Memory Clock Rate | 10.501 GHz |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin (PyTorch) | 0.519 | 0.69 | 0.719 | 0.629 |
INT8 (PyTorch) | 0.518 | 0.69 | 0.725 | 0.623 |
INT8 (TensorRT) | 0.517 | 0.685 | 0.723 | 0.623 |
Eval Model | AP Diff | AP50 Diff | Precision Diff | Recall Diff |
---|---|---|---|---|
INT8 (TensorRT) vs Origin (PyTorch) | -0.002 | -0.005 | +0.004 | -0.006 |
Model Name | Batch Size | Latency (99%) | Throughput (qps) | Total Inferences (IPS) |
---|---|---|---|---|
YOLOv9-ReLU (FP16) | 1 | 1.15 ms | 868 | 868 |
YOLOv9-ReLU (FP16) | 12 | 8.81 ms | 115 | 1380 |
YOLOv9-ReLU (INT8) | 1 | 0.84 ms | 1186 | 1186 |
YOLOv9-ReLU (INT8) | 12 | 4.59 ms | 218 | 2616 |
Model Name | Batch Size | Latency (99%) Diff | Throughput (qps) Diff | Total Inferences (IPS) Diff |
---|---|---|---|---|
(INT8) vs (FP16) | 1 | -27.0% | +36.5% | +36.5% |
(INT8) vs (FP16) | 12 | -47.9% | +89.6% | +89.6% |
Can we run inference with the PyTorch INT8 model? Is there a benchmark report comparing PyTorch INT8 vs TensorRT INT8?
@levipereira
Could you help examine the latency/throughput without NMS? Thanks in advance.
@WongKinYiu
These tests were performed without NMS.
Below is a table with additional tests.
https://github.com/levipereira/yolov9-qat?tab=readme-ov-file#latencythroughput
Thanks!
@levipereira
Excuse me, I would like to borrow your time again. Could you please help examine the latency/throughput of the following models: yolov9-c-coarse.pt, yolov9-c-fine.pt, lh-yolov9-c-coarse.pt, lh-yolov9-c-fine.pt? Thanks very much.
@WongKinYiu
Precision | Batch | Latency (percentile 99%) | Throughput (qps) | Total Throughput (IPS) |
---|---|---|---|---|
FP16 | 1 | 1.21 ms | 824 | 824 |
FP16 | 8 | 6.18 ms | 164 | 1312 |
INT8 | 1 | 0.95 ms | 1051 | 1051 |
INT8 | 8 | 3.55 ms | 281 | 2248 |
INT8 | 12 | 5.17 ms | 195 | 2340 |
Precision | Batch | Latency (percentile 99%) | Throughput (qps) | Total Throughput (IPS) |
---|---|---|---|---|
FP16 | 1 | 1.21 ms | 822 | 822 |
FP16 | 8 | 6.22 ms | 162 | 1296 |
INT8 | 1 | 0.95 ms | 1050 | 1050 |
INT8 | 8 | 3.56 ms | 281 | 2248 |
INT8 | 12 | 5.18 ms | 193 | 2316 |
Precision | Batch | Latency (percentile 99%) | Throughput (qps) | Total Throughput (IPS) |
---|---|---|---|---|
FP16 | 1 | 1.25 ms | 800 | 800 |
FP16 | 8 | 6.65 ms | 152 | 1216 |
INT8 | 1 | 0.97 ms | 1033 | 1033 |
INT8 | 8 | 3.67 ms | 272 | 2176 |
INT8 | 12 | 5.32 ms | 189 | 2268 |
Precision | Batch | Latency (percentile 99%) | Throughput (qps) | Total Throughput (IPS) |
---|---|---|---|---|
FP16 | 1 | 1.25 ms | 804 | 804 |
FP16 | 8 | 6.62 ms | 152 | 1216 |
INT8 | 1 | 0.98 ms | 1026 | 1026 |
INT8 | 8 | 3.68 ms | 271 | 2168 |
INT8 | 12 | 5.34 ms | 189 | 2268 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.527 | 0.699 | 0.74 | 0.633 |
QAT-TRT | 0.524 | 0.692 | 0.723 | 0.638 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.523 | 0.699 | 0.738 | 0.63 |
QAT-TRT | 0.522 | 0.693 | 0.743 | 0.622 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.525 | 0.701 | 0.723 | 0.639 |
QAT-TRT | 0.524 | 0.695 | 0.733 | 0.629 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.527 | 0.699 | 0.724 | 0.642 |
QAT-TRT | 0.524 | 0.693 | 0.726 | 0.631 |
@WongKinYiu QAT models have shown substantial improvements in latency and performance with minimal accuracy loss. It would be advantageous for the community to begin incorporating support for QAT in the codebase, allowing it to be activated post-training. https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/index.html#further-optimization
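For anyone who wants to experiment before such support lands in the codebase, here is a minimal sketch of how the linked pytorch-quantization toolkit is typically wired in (the `build_yolov9` constructor and the calibration batch count are placeholders, not the exact patch used in this thread):

```python
# Sketch: INT8 fake quantization with NVIDIA's pytorch-quantization toolkit.
# quant_modules.initialize() monkey-patches torch.nn layers (Conv2d, Linear, ...)
# so that any model constructed afterwards uses their quantized counterparts.
import torch
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn

quant_modules.initialize()      # must run before the model is built
model = build_yolov9().cuda()   # hypothetical constructor for the YOLOv9 model

def calibrate(model, data_loader, num_batches=64):
    """Collect activation statistics, then load amax values into every quantizer."""
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            module.disable_quant()  # gather statistics in full precision
            module.enable_calib()
    with torch.no_grad():
        for i, (images, _) in enumerate(data_loader):
            if i >= num_batches:
                break
            model(images.cuda())
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.load_calib_amax()
            module.enable_quant()
            module.disable_calib()

# After calibration, fine-tune for a few epochs (QAT) before exporting to TensorRT.
```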
Thanks a lot. It seems the fine branch loses less accuracy than the coarse branch after QAT. The coarse-to-fine models will finish training tomorrow. I will update the final weights soon.
By the way, the fine branch does not need NMS for post-processing. You could also estimate accuracy without NMS for the fine-branch models.
I am currently using the YOLOv9 code found at this link: https://github.com/levipereira/yolov9-qat/blob/master/val_trt.py#L249-L290. If you already have the corresponding code to evaluate without NMS, I would greatly appreciate it.
Currently I just remove the NMS part of `non_max_suppression` in `general.py` to implement `no_max_suppression`.
```python
def no_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Confidence filtering of inference results with the NMS step removed
    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """
    if isinstance(prediction, (list, tuple)):  # YOLO model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[1] - nm - 4  # number of classes
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 300  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 2.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x.T[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)
        box = xywh2xyxy(box)  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        if multi_label:
            i, j = (cls > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
        else:
            x = x[x[:, 4].argsort(descending=True)]  # sort by confidence

        output[xi] = x
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
```
Cleaned-up version of the code:
```python
def no_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """No Maximum Suppression on inference results
    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """
    if isinstance(prediction, (list, tuple)):  # YOLO model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[1] - nm - 4  # number of classes
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    time_limit = 2.5 + 0.05 * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        x = x.T[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)
        box = xywh2xyxy(box)  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        if multi_label:
            i, j = (cls > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_det:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_det]]  # sort by confidence
        else:
            x = x[x[:, 4].argsort(descending=True)]  # sort by confidence

        output[xi] = x
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
```
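For reference, a minimal usage sketch (the model loading and confidence threshold here are placeholders, not values from this thread):

```python
# Hypothetical drop-in usage: call no_max_suppression where the evaluation loop
# would normally call non_max_suppression, since the fine branch needs no NMS.
preds = model(imgs)                                            # raw inference output
out = no_max_suppression(preds, conf_thres=0.001, multi_label=True)
# `out` holds one (n, 6) tensor per image: [x1, y1, x2, y2, conf, cls]
```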
I have completely retrained the model due to the change that eliminates the need for NMS processing. We pick the best model based on mAP.
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.523 | 0.693 | 0.735 | 0.621 |
QAT-TRT | 0.52 | 0.686 | 0.735 | 0.617 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.524 | 0.695 | 0.728 | 0.629 |
QAT-TRT | 0.521 | 0.687 | 0.731 | 0.618 |
Note: Despite the good results, the model is still not 100% quantized. We can further improve performance with minimal loss. Additionally, we can recover the loss due to quantization by applying other calibration methods.
lh-yolov9-c-coarse.pt, lh-yolov9-c-fine.pt are updated.
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.526 | 0.696 | 0.743 | 0.628 |
PTQ (Baseline) | 0.514 | 0.681 | 0.724 | 0.613 |
QAT (PyTorch) | 0.517 | 0.683 | 0.744 | 0.604 |
QAT-TRT | 0.516 | 0.679 | 0.724 | 0.617 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.526 | 0.696 | 0.743 | 0.628 |
PTQ (Baseline) | 0.517 | 0.686 | 0.742 | 0.615 |
QAT (PyTorch) | 0.518 | 0.687 | 0.738 | 0.616 |
QAT-TRT | 0.517 | 0.682 | 0.739 | 0.614 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.528 | 0.700 | 0.743 | 0.634 |
PTQ (Baseline) | 0.524 | 0.696 | 0.734 | 0.631 |
QAT (PyTorch) | 0.526 | 0.697 | 0.741 | 0.631 |
QAT-TRT | 0.526 | 0.692 | 0.733 | 0.634 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.528 | 0.700 | 0.743 | 0.634 |
PTQ (Baseline) | 0.525 | 0.697 | 0.742 | 0.628 |
QAT (PyTorch) | 0.527 | 0.699 | 0.742 | 0.634 |
QAT-TRT | 0.526 | 0.692 | 0.744 | 0.631 |
I initially performed the default MSE calibration, but the results were unsatisfactory. Consequently, I changed the calibration method to use percentile=99.999, which yielded better outcomes. I believe these models have more sensitive layers that need to be treated differently. Additionally, I need to explore the new head of the model, since I have only performed quantization for YOLOv9-C/E.
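For context, switching the calibration method with the pytorch-quantization toolkit comes down to how the collected statistics are reduced to amax values; a minimal sketch (assuming histogram calibrators were attached during calibration, with a MaxCalibrator fallback as in NVIDIA's examples):

```python
from pytorch_quantization import calib
from pytorch_quantization import nn as quant_nn

def compute_amax(model, **kwargs):
    # Histogram calibrators accept a reduction method ("mse", "percentile", "entropy");
    # max calibrators take no arguments.
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer) and module._calibrator is not None:
            if isinstance(module._calibrator, calib.MaxCalibrator):
                module.load_calib_amax()
            else:
                module.load_calib_amax(**kwargs)

# compute_amax(model, method="mse")                          # default tried first
compute_amax(model, method="percentile", percentile=99.999)  # recovered more accuracy here
```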
I am generating a latency report.
Batch Size | Latency (percentile 99%) | Throughput (qps) | Total Throughput (IPS) |
---|---|---|---|
1 | 0.94 ms | 1059 | 1059 |
8 | 3.56 ms | 282 | 2256 |
12 | 5.18 ms | 194 | 2328 |
Batch Size | Latency (percentile 99%) | Throughput (qps) | Total Throughput (IPS) |
---|---|---|---|
1 | 0.95 ms | 1048 | 1048 |
8 | 3.56 ms | 281 | 2248 |
12 | 5.21 ms | 193 | 2316 |
Thanks!
It seems the old weights (https://github.com/WongKinYiu/yolov9/issues/327#issuecomment-2144177000) have more stable QAT performance. Although the new weights (https://github.com/WongKinYiu/yolov9/issues/327#issuecomment-2145924313) have higher original AP, they lose more performance after QAT.
Since the old and new weights are trained with different strategies, maybe it is worth discussing the relation between pretraining methods and the QAT step.
I will provide weights for YOLOv9-C-FINE trained the same way as https://github.com/WongKinYiu/yolov9/issues/327#issuecomment-2145924313 in a few days to determine whether the sensitive layers are caused by the different training methods.
If yes, I could try to analyze and design QAT-friendly pretraining methods in the future.
Thank you for bringing this possible research direction to me.
I ran the tests to find the most sensitive layer (PTQ baseline) for the model from https://github.com/WongKinYiu/yolov9/issues/327#issuecomment-2145133874, and here are the results.
Sensitivity summary:
Top0: Using fp16 model.22, ap = 0.52310
Top1: Using fp16 model.4, ap = 0.51660
Top2: Using fp16 model.3, ap = 0.51650
Top3: Using fp16 model.1, ap = 0.51570
Top4: Using fp16 model.15, ap = 0.51560
Top5: Using fp16 model.17, ap = 0.51560
Top6: Using fp16 model.2, ap = 0.51550
Top7: Using fp16 model.8, ap = 0.51550
Top8: Using fp16 model.11, ap = 0.51550
Top9: Using fp16 model.14, ap = 0.51550
Top10: Using fp16 model.21, ap = 0.51550
Top11: Using fp16 model.0, ap = 0.51540
Top12: Using fp16 model.9, ap = 0.51540
Top13: Using fp16 model.18, ap = 0.51540
Top14: Using fp16 model.5, ap = 0.51530
Top15: Using fp16 model.6, ap = 0.51530
Top16: Using fp16 model.12, ap = 0.51530
Top17: Using fp16 model.19, ap = 0.51530
Top18: Using fp16 model.20, ap = 0.51520
Top19: Using fp16 model.7, ap = 0.51500
Top20: Using fp16 PTQ, ap = 0.51490
Top21: Using fp16 model.10, ap = 0.51490
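For reference, a minimal sketch of the kind of per-block sweep that produces a summary like the one above (an assumed workflow, not the exact script used here; `evaluate_ap` is a hypothetical stand-in for the COCO evaluation loop):

```python
from pytorch_quantization import nn as quant_nn

def block_quantizers(model, prefix):
    """Yield every TensorQuantizer under a top-level block such as 'model.22'."""
    for name, module in model.named_modules():
        if (name == prefix or name.startswith(prefix + ".")) and isinstance(module, quant_nn.TensorQuantizer):
            yield module

def sensitivity_sweep(model, block_indices, evaluate_ap):
    results = {}
    for i in block_indices:
        prefix = f"model.{i}"
        quantizers = list(block_quantizers(model, prefix))
        if not quantizers:
            continue
        for q in quantizers:
            q.disable()              # run this block in FP16 instead of fake-INT8
        results[prefix] = evaluate_ap(model)
        for q in quantizers:
            q.enable()               # restore quantization for the next candidate
    # The block that yields the highest AP when left un-quantized is the most sensitive.
    return sorted(results.items(), key=lambda kv: kv[1], reverse=True)
```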
Today was quite busy, but I believe I will be able to run the training with layer 22 kept in FP16 and see the performance and accuracy results.
Indeed, layer 22 is the most sensitive layer. I disabled the quantization in layer 22 and managed to recover the precision with better performance at batch size 1. However, when increasing the batch size to 8 or 12, there is a slight increase in latency and a decrease in throughput.
(https://github.com/WongKinYiu/yolov9/issues/327#issuecomment-2145133874)
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.526 | 0.696 | 0.743 | 0.628 |
PTQ (Baseline) | 0.522 | 0.693 | 0.74 | 0.627 |
QAT (PyTorch) | 0.524 | 0.694 | 0.738 | 0.626 |
QAT-TRT | 0.525 | 0.69 | 0.743 | 0.622 |
Batch Size | Latency (percentile 99%) | Throughput (qps) | Total Throughput (IPS) |
---|---|---|---|
1 | 0.92 ms | 1079 | 1079 |
8 | 3.76 ms | 266 | 2128 |
12 | 5.65 ms | 178 | 2136 |
I uploaded the weights and updated the file names.
training method 1: yolov9-c-coarse.pt, yolov9-c-fine.pt, lh-yolov9-c-coarse.pt, lh-yolov9-c-fine.pt.
training method 2: yolov9-c-coarse-.pt, yolov9-c-fine-.pt, lh-yolov9-c-coarse-.pt, lh-yolov9-c-fine-.pt.
By the way, could you also help examine the latency/throughput of the tiny/small/medium models: yolov9-t-converted.pt, yolov9-s-converted.pt, yolov9-m-converted.pt.
Thanks.
I have observed that the last layer of the model is often the most sensitive to quantization. This sensitivity arises because this layer tends to generate more outliers. When quantized, these outliers are normalized away, leading to a loss of precision, even though they are crucial for the model's accuracy.
By changing the training method, we have effectively reduced the generation of outliers, which are critical for quantization. The different training approach has shown to produce fewer values that are considered outliers, thus preserving the precision and overall performance of the quantized model.
To address the sensitivity of the final layer to quantization, I implemented a straightforward approach: disabling the quantization of layer 22. Instead of retraining the model, I simply disabled the quantization for this specific layer and re-evaluated the model to assess the impact on performance.
Quantization disabled at layer 22 is indicated by the suffix `-D22`.
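A minimal sketch of that approach with the pytorch-quantization toolkit (the checkpoint loader, input shape, and output path are illustrative):

```python
import torch
from pytorch_quantization import nn as quant_nn

model = load_qat_checkpoint()  # hypothetical loader for the calibrated QAT model

# Disable the fake-quantizers of the sensitive block so it stays in higher precision
# in the exported graph (the "-D22" variants in the tables below).
for name, module in model.named_modules():
    if (name == "model.22" or name.startswith("model.22.")) and isinstance(module, quant_nn.TensorQuantizer):
        module.disable()

# Export with Q/DQ nodes so TensorRT can build an INT8 engine from the ONNX graph.
quant_nn.TensorQuantizer.use_fb_fake_quant = True
torch.onnx.export(model.cuda(), torch.zeros(1, 3, 640, 640).cuda(),
                  "yolov9-qat-d22.onnx", opset_version=13)
```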
I performed the calibration using MSE, although in some cases using percentile=99.999 proved to be more effective.
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.527 | 0.699 | 0.74 | 0.633 |
PTQ | 0.523 | 0.696 | 0.733 | 0.629 |
QAT-PyT | 0.525 | 0.697 | 0.741 | 0.624 |
QAT-TRT | 0.524 | 0.692 | 0.722 | 0.638 |
QAT-PyT-D22 | 0.525 | 0.694 | 0.732 | 0.631 |
QAT-TRT-D22 | 0.524 | 0.692 | 0.725 | 0.635 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.527 | 0.693 | 0.728 | 0.636 |
PTQ | 0.525 | 0.692 | 0.739 | 0.630 |
QAT-PyT | 0.526 | 0.692 | 0.738 | 0.629 |
QAT-TRT | 0.526 | 0.693 | 0.725 | 0.638 |
QAT-PyT-D22 | 0.526 | 0.693 | 0.730 | 0.635 |
QAT-TRT-D22 | 0.526 | 0.693 | 0.730 | 0.634 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.527 | 0.699 | 0.724 | 0.642 |
PTQ | 0.524 | 0.696 | 0.728 | 0.628 |
QAT-PyT | 0.525 | 0.697 | 0.731 | 0.633 |
QAT-TRT | 0.524 | 0.692 | 0.722 | 0.638 |
QAT-PyT-D22 | 0.525 | 0.694 | 0.715 | 0.639 |
QAT-TRT-D22 | 0.524 | 0.693 | 0.722 | 0.638 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.528 | 0.696 | 0.743 | 0.634 |
PTQ | 0.524 | 0.693 | 0.734 | 0.631 |
QAT-PyT | 0.526 | 0.693 | 0.741 | 0.631 |
QAT-TRT | 0.525 | 0.693 | 0.734 | 0.634 |
QAT-PyT-D22 | 0.527 | 0.694 | 0.742 | 0.631 |
QAT-TRT-D22 | 0.526 | 0.693 | 0.734 | 0.634 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.523 | 0.689 | 0.735 | 0.621 |
PTQ | 0.519 | 0.685 | 0.723 | 0.627 |
QAT-PyT | 0.520 | 0.687 | 0.740 | 0.619 |
QAT-TRT | 0.520 | 0.686 | 0.734 | 0.619 |
QAT-PyT-D22 | 0.520 | 0.686 | 0.737 | 0.619 |
QAT-TRT-D22 | 0.520 | 0.686 | 0.734 | 0.620 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.523 | 0.688 | 0.733 | 0.626 |
PTQ | 0.521 | 0.685 | 0.725 | 0.623 |
QAT-PyT | 0.522 | 0.686 | 0.734 | 0.621 |
QAT-TRT | 0.522 | 0.685 | 0.734 | 0.615 |
QAT-PyT-D22 | 0.522 | 0.686 | 0.711 | 0.629 |
QAT-TRT-D22 | 0.522 | 0.685 | 0.726 | 0.620 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.524 | 0.695 | 0.728 | 0.629 |
PTQ | 0.520 | 0.691 | 0.730 | 0.620 |
QAT-PyT | 0.521 | 0.691 | 0.740 | 0.614 |
QAT-TRT | 0.521 | 0.687 | 0.741 | 0.616 |
QAT-PyT-D22 | 0.522 | 0.689 | 0.733 | 0.620 |
QAT-TRT-D22 | 0.523 | 0.689 | 0.741 | 0.615 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.526 | 0.692 | 0.743 | 0.628 |
PTQ | 0.515 | 0.678 | 0.724 | 0.613 |
QAT-PyT | 0.517 | 0.679 | 0.735 | 0.611 |
QAT-TRT | 0.516 | 0.679 | 0.730 | 0.612 |
QAT-PyT-D22 | 0.524 | 0.690 | 0.752 | 0.620 |
QAT-TRT-D22 | 0.524 | 0.690 | 0.750 | 0.620 |
I still owe the tests for the remaining models as well as the latency tests, which I will send as soon as possible.
I have encountered several performance issues regarding latency and throughput in the quantized Tiny, Small, and Medium models. They performed worse than the FP16 models, generating many reformat operations that directly impacted the model's latency. I am currently researching and studying the behavior of quantization in these models to resolve the issue.
Batch Size | Latency (percentile 99%) | Throughput (qps) | Total Throughput (IPS) |
---|---|---|---|
1 | 0.78 ms | 1282 | 1282 |
12 | 2.32 ms | 432 | 5184 |
24 | 4.40 ms | 228 | 5472 |
32 | 5.94 ms | 169 | 5408 |
Batch Size | Latency (percentile 99%) | Throughput (qps) | Total Throughput (IPS) |
---|---|---|---|
1 | 0.92 ms | 1086 | 1086 |
12 | 4.19 ms | 240 | 2880 |
24 | 8.25 ms | 122 | 2928 |
32 | 11.07 ms | 91 | 2912 |
Batch Size | Latency (percentile 99%) | Throughput (qps) | Total Throughput (IPS) |
---|---|---|---|
1 | 1.22 ms | 818 | 818 |
8 | 5.81 ms | 175 | 1400 |
12 | 8.78 ms | 116 | 1392 |
24 | 17.90 ms | 57 | 1368 |
32 | 24.09 ms | 42 | 1344 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.38 | 0.527 | 0.633 | 0.481 |
PTQ | 0.374 | 0.52 | 0.63 | 0.475 |
QAT - Best | 0.376 | 0.523 | 0.641 | 0.473 |
QAT - TRT | 0.376 | 0.523 | 0.64 | 0.476 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.467 | 0.629 | 0.698 | 0.569 |
PTQ | 0.462 | 0.622 | 0.689 | 0.565 |
QAT - Best | 0.465 | 0.626 | 0.688 | 0.576 |
QAT - TRT | 0.465 | 0.626 | 0.692 | 0.57 |
Eval Model | AP | AP50 | Precision | Recall |
---|---|---|---|---|
Origin | 0.513 | 0.677 | 0.713 | 0.62 |
PTQ | 0.512 | 0.675 | 0.713 | 0.619 |
QAT - Best | 0.512 | 0.675 | 0.712 | 0.621 |
QAT - TRT | 0.511 | 0.675 | 0.712 | 0.622 |
Thanks.
Yes, the throughput seems strange. yolov9-m gets the same throughput as yolov9-c.
The main difference between t/s/m and c/e is that t/s/m use AConv and c/e use ADown for downsampling. In PyTorch, AConv is faster than ADown, but I am not sure about the situation with TensorRT and quantization.
I have noticed that model performance is often measured solely by latency. However, during my research, I discovered that different models can have very similar latencies with a batch size of 1. But as the batch size increases, they show significant differences in both throughput and latency. Therefore, testing only with a batch size of 1 and focusing solely on latency can lead to incorrect conclusions about the model's potential.
To accurately measure a model's potential, we should consider both latency and batch size. On the GPU, models can have a certain latency, but increasing the batch size doesn't cause latency to grow proportionally. This is evident in the performance tables. Thus, the best model is the one that achieves the highest throughput with the largest batch size and the lowest latency.
I will attempt to illustrate my findings visually.
When SM Active reaches 100%, the model's performance drops, resulting in increased latency and decreased throughput.
This is why, when measuring a model's potential, we should also consider the batch size, not only the latency at batch size 1.
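As a small worked example of that point (total throughput is simply queries per second times batch size, numbers taken from the YOLOv9-t and YOLOv9-s tables above):

```python
# Total Throughput (IPS) = Throughput (qps) * Batch Size
runs = {"yolov9-t": {"batch": 12, "latency_ms": 2.32, "qps": 432},
        "yolov9-s": {"batch": 12, "latency_ms": 4.19, "qps": 240}}

for name, r in runs.items():
    ips = r["qps"] * r["batch"]
    print(f'{name}: {r["latency_ms"]} ms @ batch {r["batch"]} -> {ips} images/s')
# yolov9-t: 2.32 ms @ batch 12 -> 5184 images/s
# yolov9-s: 4.19 ms @ batch 12 -> 2880 images/s
```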
> yolov9-m gets the same throughput as yolov9-c.
I will perform profiling to see the differences.
Batch 1 and large batch are both important.
Large-batch inference is important for cloud services. Batch-1 inference is important for streaming input.
> yolov9-m gets the same throughput as yolov9-c.
I have encountered similar issues (small model and large model having the same inference speed) on YOLOv4 when using some built-in PyTorch versions in the NVIDIA docker. My solution was to reinstall PyTorch and related dependencies. I am not sure if you are facing the same issue.
> I have encountered similar issues (small model and large model having the same inference speed) on YOLOv4 when using some built-in PyTorch versions in the NVIDIA docker.
I will test these models on different servers and TensorRT version.
I often see performance reports comparing YOLO-series models at a batch size of 1, using latency as the primary comparison metric. However, without varying the batch size, some models may perform significantly worse at larger batch sizes than others. A classic example: at batch size 1, YOLOv9-t had a latency of 0.7 ms versus 0.9 ms for YOLOv9-s, and the throughput difference was only about 280 IPS. At batch size 12, however, the YOLOv9-s latency was almost double that of YOLOv9-t (4.19 ms vs 2.32 ms), and the throughput difference was significant, with YOLOv9-t achieving nearly 2500 more IPS than YOLOv9-s (5184 IPS vs 2880 IPS). By focusing solely on latency at batch size 1, I would be overlooking the full potential of YOLOv9-t.
laugh12321 gets similar inference speeds to your reports.
Three possible reasons:
Since the c model has 13 times the FLOPs of the t model, this is really strange.
I have never met this situation on our platform.
Could you help by switching to the root user (`sudo -s`) and testing the speed?
To check whether the number of layers is one of the reasons, could you help test gelan-s2.pt?
number of layers: e > t = s > s1 = c > m
> Could you help test gelan-s2.pt?
Batch Size | Latency (percentile 99%) | Throughput (qps) | Total Throughput (IPS) |
---|---|---|---|
1 | 0.805 ms | 1242 | 1242 |
12 | 3.872 ms | 259 | 3108 |
24 | 7.766 ms | 130 | 3120 |
32 | 10.388 ms | 97 | 3104 |
> Since the c model has 13 times the FLOPs of the t model, this is really strange. Could you help by switching to the root user (`sudo -s`) and testing the speed?
I don't believe the problem is with the host or the installation. It may be a bug/issue in TensorRT, because only a few models exhibit this strange behavior. I will install TensorRT Engine Explorer and get results.
I'm having a lot of difficulty identifying why the t/s/m models perform poorly when quantized. I've noticed a lot of reformat operations due to different scales. I implemented AConv similarly to ADown, but the poor results persist. I also observed some DFL operations in the slice of the initial layers that differ from YOLOv9-C. However, I'm still investigating this carefully.
These reformat operations are killing me.
Thank you for your effort.
Yes, it seems many unnecessary reformat layers are generated by TensorRT.
I am not sure if this helps:
> "It is possible to make TensorRT avoid inserting reformatting at the network boundaries, by setting the builder configuration flag `DIRECT_IO`."
Regarding https://github.com/WongKinYiu/yolov9/issues/327#issuecomment-2156066103: I performed profiling for each model individually and then conducted a comparative analysis between the models (C vs. M, M vs. S, S vs. T). Google Drive files: https://drive.google.com/drive/folders/18vBxAWZmQ1KUV7Tga9yH_fL5YSbbdYzw?usp=sharing
Well, I do not know why, after converting to TensorRT, yolov9-m has so many layers.
I have been analyzing the models and noticed that, compared with YOLOv9-C, YOLOv9-M has several reformat operations where some nodes were not fused. The same issue occurs with the QAT models, where some nodes, despite being on the same scale, are not being fused, resulting in multiple reformat operations.
I searched on GitHub and found several users experiencing issues with node fusion, where TensorRT did not support certain fusions. Given that these models introduce new modules, it is possible that this has caused issues with TensorRT.
We need to open another front to address this issue in the TensorRT repository to understand where the potential problem lies.
Could you help take a look at whether YOLOv7 has the same issue? If not, I could point out the main differences between the YOLOv7 and YOLOv9 architectures.
These past few days I was away on a business trip. I'm returning now and we will pick up where we left off. I'm sorry for the delay in responding.
YOLOv9 with Quantization-Aware Training (QAT) for TensorRT
https://github.com/levipereira/yolov9-qat/ This repository hosts an implementation of YOLOv9 integrated with Quantization-Aware Training (QAT), optimized for deployment on TensorRT-supported platforms to achieve hardware-accelerated inference. It aims to deliver an efficient, low-latency version of YOLOv9 for real-time object detection applications. If you're not planning to deploy your model using TensorRT, it's advisable not to proceed with this implementation.
Implementation Details:
Performance Report
@WongKinYiu I've successfully created a comprehensive implementation of Quantization in a separate repository. It works as a patch for the original YOLOv9 version. However, there are still some challenges to address as the implementation is functional but has room for improvement.
I'm closing the issue #253 and will continue the discussion in this thread. If possible, please replace the reference to issue #253 with this new issue #327 in the Useful Links section.
I'll provide the latency reports shortly.