Deci-AI / super-gradients

Easily train or fine-tune SOTA computer vision models with one open source training library. The home of Yolo-NAS.
https://www.supergradients.com
Apache License 2.0

YOLONAS: load `ckpt_best.pth` after doing `qat_from_recipe` and export with different settings #1808

Open nsabir2011 opened 7 months ago

nsabir2011 commented 7 months ago

💡 Your Question

After training and applying QAT with python -m super_gradients.qat_from_recipe, the model is exported automatically. However, I have no control over how it is exported. For example, the exported model had a batch size of 12 and did not include the pre- and post-processing steps. I need them included, because I don't know what those pre- and post-processing steps are.
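For reference, the invocation I mean looks roughly like this (it is a Hydra entry point, so the recipe name here is just a placeholder for whatever QAT recipe is used):

python -m super_gradients.qat_from_recipe --config-name=<your_qat_recipe>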

So I tried to load and export the model with the following code (as per this guide):

from super_gradients.training import models
from super_gradients.common.object_names import Models
from super_gradients.conversion import ExportTargetBackend, ExportQuantizationMode

yolonas = models.get(Models.YOLO_NAS_S, checkpoint_path="<path-to-ckpt_best.pth>", num_classes=80)
yolonas.export(
    "yolonas_s_int8.onnx",
    preprocessing=True,
    postprocessing=True,
    engine=ExportTargetBackend.TENSORRT,
    quantization_mode=ExportQuantizationMode.INT8,
)

But I am getting this error:

RuntimeError                              Traceback (most recent call last)
File ~/venvs/s-grads/lib/python3.10/site-packages/super_gradients/training/utils/checkpoint_utils.py:99, in adaptive_load_state_dict(net, state_dict, strict, solver)
     98     strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
---> 99     net.load_state_dict(state_dict, strict=strict_bool)
    100 except (RuntimeError, ValueError, KeyError) as ex:

File ~/venvs/s-grads/lib/python3.10/site-packages/torch/nn/modules/module.py:2152, in Module.load_state_dict(self, state_dict, strict, assign)
   2151 if len(error_msgs) > 0:
-> 2152     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2153                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2154 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for YoloNAS_S:
        Missing key(s) in state_dict: "backbone.stem.conv.branch_3x3.conv.weight", "backbone.stem.conv.branch_3x3.bn.weight", "backbone.stem.conv.branch_3x3.bn.bias", "backbone.stem.conv.branch_3x3.bn.running_mean", "backbone.stem.conv.branch_3x3.bn.running_var", "backbone.stem.conv.branch_1x1.weight", "backbone.stem.conv.branch_1x1.bias", "backbone.stage1.downsample.branch_3x3.conv.weight", "backbone.stage1.downsample.branch_3x3.bn.weight", "backbone.stage1.downsample.branch_3x3.bn.bias", "backbone.stage1.downsample.branch_3x3.bn.running_mean", "backbone.stage1.downsample.branch_3x3.bn.running_var", "backbone.stage1.downsample.branch_1x1.weight", "backbone.stage1.downsample.branch_1x1.bias", "backbone.stage1.blocks.bottlenecks.0.cv1.branch_3x3.conv.weight", "backbone.stage1.blocks.bottlenecks.0.cv1.branch_3x3.bn.weight", "backbone.stage1.blocks.bottlenecks.0.cv1.branch_3x3.bn.bias", "backbone.stage1.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_mean", "backbone.stage1.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_var", "backbone.stage1.blocks.bottlenecks.0.cv1.branch_1x1.weight", "backbone.stage1.blocks.bottlenecks.0.cv1.branch_1x1.bias", "backbone.stage1.blocks.bottlenecks.0.cv2.branch_3x3.conv.weight", "backbone.stage1.blocks.bottlenecks.0.cv2.branch_3x3.bn.weight", "backbone.stage1.blocks.bottlenecks.0.cv2.branch_3x3.bn.bias", "backbone.stage1.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_mean", "backbone.stage1.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_var", "backbone.stage1.blocks.bottlenecks.0.cv2.branch_1x1.weight", "backbone.stage1.blocks.bottlenecks.0.cv2.branch_1x1.bias", "backbone.stage1.blocks.bottlenecks.1.cv1.branch_3x3.conv.weight", "backbone.stage1.blocks.bottlenecks.1.cv1.branch_3x3.bn.weight", "backbone.stage1.blocks.bottlenecks.1.cv1.branch_3x3.bn.bias", "backbone.stage1.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_mean", "backbone.stage1.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_var", "backbone.stage1.blocks.bottlenecks.1.cv1.branch_1x1.weight", "backbone.stage1.blocks.bottlenecks.1.cv1.branch_1x1.bias", "backbone.stage1.blocks.bottlenecks.1.cv2.branch_3x3.conv.weight", "backbone.stage1.blocks.bottlenecks.1.cv2.branch_3x3.bn.weight", "backbone.stage1.blocks.bottlenecks.1.cv2.branch_3x3.bn.bias", "backbone.stage1.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_mean", "backbone.stage1.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_var", "backbone.stage1.blocks.bottlenecks.1.cv2.branch_1x1.weight", "backbone.stage1.blocks.bottlenecks.1.cv2.branch_1x1.bias", "backbone.stage2.downsample.branch_3x3.conv.weight", "backbone.stage2.downsample.branch_3x3.bn.weight", "backbone.stage2.downsample.branch_3x3.bn.bias", "backbone.stage2.downsample.branch_3x3.bn.running_mean", "backbone.stage2.downsample.branch_3x3.bn.running_var", "backbone.stage2.downsample.branch_1x1.weight", "backbone.stage2.downsample.branch_1x1.bias", "backbone.stage2.blocks.bottlenecks.0.cv1.branch_3x3.conv.weight", "backbone.stage2.blocks.bottlenecks.0.cv1.branch_3x3.bn.weight", "backbone.stage2.blocks.bottlenecks.0.cv1.branch_3x3.bn.bias", "backbone.stage2.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_mean", "backbone.stage2.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_var", "backbone.stage2.blocks.bottlenecks.0.cv1.branch_1x1.weight", "backbone.stage2.blocks.bottlenecks.0.cv1.branch_1x1.bias", "backbone.stage2.blocks.bottlenecks.0.cv2.branch_3x3.conv.weight", "backbone.stage2.blocks.bottlenecks.0.cv2.branch_3x3.bn.weight", "backbone.stage2.blocks.bottlenecks.0.cv2.branch_3x3.bn.bias", 
"backbone.stage2.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_mean", "backbone.stage2.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_var", "backbone.stage2.blocks.bottlenecks.0.cv2.branch_1x1.weight", "backbone.stage2.blocks.bottlenecks.0.cv2.branch_1x1.bias", "backbone.stage2.blocks.bottlenecks.1.cv1.branch_3x3.conv.weight", "backbone.stage2.blocks.bottlenecks.1.cv1.branch_3x3.bn.weight", "backbone.stage2.blocks.bottlenecks.1.cv1.branch_3x3.bn.bias", "backbone.stage2.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_mean", "backbone.stage2.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_var", "backbone.stage2.blocks.bottlenecks.1.cv1.branch_1x1.weight", "backbone.stage2.blocks.bottlenecks.1.cv1.branch_1x1.bias", "backbone.stage2.blocks.bottlenecks.1.cv2.branch_3x3.conv.weight", "backbone.stage2.blocks.bottlenecks.1.cv2.branch_3x3.bn.weight", "backbone.stage2.blocks.bottlenecks.1.cv2.branch_3x3.bn.bias", "backbone.stage2.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_mean", "backbone.stage2.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_var", "backbone.stage2.blocks.bottlenecks.1.cv2.branch_1x1.weight", "backbone.stage2.blocks.bottlenecks.1.cv2.branch_1x1.bias", "backbone.stage2.blocks.bottlenecks.2.cv1.branch_3x3.conv.weight", "backbone.stage2.blocks.bottlenecks.2.cv1.branch_3x3.bn.weight", "backbone.stage2.blocks.bottlenecks.2.cv1.branch_3x3.bn.bias", "backbone.stage2.blocks.bottlenecks.2.cv1.branch_3x3.bn.running_mean", "backbone.stage2.blocks.bottlenecks.2.cv1.branch_3x3.bn.running_var", "backbone.stage2.blocks.bottlenecks.2.cv1.branch_1x1.weight", "backbone.stage2.blocks.bottlenecks.2.cv1.branch_1x1.bias", "backbone.stage2.blocks.bottlenecks.2.cv2.branch_3x3.conv.weight", "backbone.stage2.blocks.bottlenecks.2.cv2.branch_3x3.bn.weight", "backbone.stage2.blocks.bottlenecks.2.cv2.branch_3x3.bn.bias", "backbone.stage2.blocks.bottlenecks.2.cv2.branch_3x3.bn.running_mean", "backbone.stage2.blocks.bottlenecks.2.cv2.branch_3x3.bn.running_var", "backbone.stage2.blocks.bottlenecks.2.cv2.branch_1x1.weight", "backbone.stage2.blocks.bottlenecks.2.cv2.branch_1x1.bias", "backbone.stage3.downsample.branch_3x3.conv.weight", "backbone.stage3.downsample.branch_3x3.bn.weight", "backbone.stage3.downsample.branch_3x3.bn.bias", "backbone.stage3.downsample.branch_3x3.bn.running_mean", "backbone.stage3.downsample.branch_3x3.bn.running_var", "backbone.stage3.downsample.branch_1x1.weight", "backbone.stage3.downsample.branch_1x1.bias", "backbone.stage3.blocks.bottlenecks.0.cv1.branch_3x3.conv.weight", "backbone.stage3.blocks.bottlenecks.0.cv1.branch_3x3.bn.weight", "backbone.stage3.blocks.bottlenecks.0.cv1.branch_3x3.bn.bias", "backbone.stage3.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_mean", "backbone.stage3.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_var", "backbone.stage3.blocks.bottlenecks.0.cv1.branch_1x1.weight", "backbone.stage3.blocks.bottlenecks.0.cv1.branch_1x1.bias", "backbone.stage3.blocks.bottlenecks.0.cv2.branch_3x3.conv.weight", "backbone.stage3.blocks.bottlenecks.0.cv2.branch_3x3.bn.weight", "backbone.stage3.blocks.bottlenecks.0.cv2.branch_3x3.bn.bias", "backbone.stage3.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_mean", "backbone.stage3.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_var", "backbone.stage3.blocks.bottlenecks.0.cv2.branch_1x1.weight", "backbone.stage3.blocks.bottlenecks.0.cv2.branch_1x1.bias", "backbone.stage3.blocks.bottlenecks.1.cv1.branch_3x3.conv.weight", "backbone.stage3.blocks.bottlenecks.1.cv1.branch_3x3.bn.weight", 
"backbone.stage3.blocks.bottlenecks.1.cv1.branch_3x3.bn.bias", "backbone.stage3.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_mean", "backbone.stage3.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_var", "backbone.stage3.blocks.bottlenecks.1.cv1.branch_1x1.weight", "backbone.stage3.blocks.bottlenecks.1.cv1.branch_1x1.bias", "backbone.stage3.blocks.bottlenecks.1.cv2.branch_3x3.conv.weight", "backbone.stage3.blocks.bottlenecks.1.cv2.branch_3x3.bn.weight", "backbone.stage3.blocks.bottlenecks.1.cv2.branch_3x3.bn.bias", "backbone.stage3.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_mean", "backbone.stage3.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_var", "backbone.stage3.blocks.bottlenecks.1.cv2.branch_1x1.weight", "backbone.stage3.blocks.bottlenecks.1.cv2.branch_1x1.bias", "backbone.stage3.blocks.bottlenecks.2.cv1.branch_3x3.conv.weight", "backbone.stage3.blocks.bottlenecks.2.cv1.branch_3x3.bn.weight", "backbone.stage3.blocks.bottlenecks.2.cv1.branch_3x3.bn.bias", "backbone.stage3.blocks.bottlenecks.2.cv1.branch_3x3.bn.running_mean", "backbone.stage3.blocks.bottlenecks.2.cv1.branch_3x3.bn.running_var", "backbone.stage3.blocks.bottlenecks.2.cv1.branch_1x1.weight", "backbone.stage3.blocks.bottlenecks.2.cv1.branch_1x1.bias", "backbone.stage3.blocks.bottlenecks.2.cv2.branch_3x3.conv.weight", "backbone.stage3.blocks.bottlenecks.2.cv2.branch_3x3.bn.weight", "backbone.stage3.blocks.bottlenecks.2.cv2.branch_3x3.bn.bias", "backbone.stage3.blocks.bottlenecks.2.cv2.branch_3x3.bn.running_mean", "backbone.stage3.blocks.bottlenecks.2.cv2.branch_3x3.bn.running_var", "backbone.stage3.blocks.bottlenecks.2.cv2.branch_1x1.weight", "backbone.stage3.blocks.bottlenecks.2.cv2.branch_1x1.bias", "backbone.stage3.blocks.bottlenecks.3.cv1.branch_3x3.conv.weight", "backbone.stage3.blocks.bottlenecks.3.cv1.branch_3x3.bn.weight", "backbone.stage3.blocks.bottlenecks.3.cv1.branch_3x3.bn.bias", "backbone.stage3.blocks.bottlenecks.3.cv1.branch_3x3.bn.running_mean", "backbone.stage3.blocks.bottlenecks.3.cv1.branch_3x3.bn.running_var", "backbone.stage3.blocks.bottlenecks.3.cv1.branch_1x1.weight", "backbone.stage3.blocks.bottlenecks.3.cv1.branch_1x1.bias", "backbone.stage3.blocks.bottlenecks.3.cv2.branch_3x3.conv.weight", "backbone.stage3.blocks.bottlenecks.3.cv2.branch_3x3.bn.weight", "backbone.stage3.blocks.bottlenecks.3.cv2.branch_3x3.bn.bias", "backbone.stage3.blocks.bottlenecks.3.cv2.branch_3x3.bn.running_mean", "backbone.stage3.blocks.bottlenecks.3.cv2.branch_3x3.bn.running_var", "backbone.stage3.blocks.bottlenecks.3.cv2.branch_1x1.weight", "backbone.stage3.blocks.bottlenecks.3.cv2.branch_1x1.bias", "backbone.stage3.blocks.bottlenecks.4.cv1.branch_3x3.conv.weight", "backbone.stage3.blocks.bottlenecks.4.cv1.branch_3x3.bn.weight", "backbone.stage3.blocks.bottlenecks.4.cv1.branch_3x3.bn.bias", "backbone.stage3.blocks.bottlenecks.4.cv1.branch_3x3.bn.running_mean", "backbone.stage3.blocks.bottlenecks.4.cv1.branch_3x3.bn.running_var", "backbone.stage3.blocks.bottlenecks.4.cv1.branch_1x1.weight", "backbone.stage3.blocks.bottlenecks.4.cv1.branch_1x1.bias", "backbone.stage3.blocks.bottlenecks.4.cv2.branch_3x3.conv.weight", "backbone.stage3.blocks.bottlenecks.4.cv2.branch_3x3.bn.weight", "backbone.stage3.blocks.bottlenecks.4.cv2.branch_3x3.bn.bias", "backbone.stage3.blocks.bottlenecks.4.cv2.branch_3x3.bn.running_mean", "backbone.stage3.blocks.bottlenecks.4.cv2.branch_3x3.bn.running_var", "backbone.stage3.blocks.bottlenecks.4.cv2.branch_1x1.weight", "backbone.stage3.blocks.bottlenecks.4.cv2.branch_1x1.bias", 
"backbone.stage4.downsample.branch_3x3.conv.weight", "backbone.stage4.downsample.branch_3x3.bn.weight", "backbone.stage4.downsample.branch_3x3.bn.bias", "backbone.stage4.downsample.branch_3x3.bn.running_mean", "backbone.stage4.downsample.branch_3x3.bn.running_var", "backbone.stage4.downsample.branch_1x1.weight", "backbone.stage4.downsample.branch_1x1.bias", "backbone.stage4.blocks.bottlenecks.0.cv1.branch_3x3.conv.weight", "backbone.stage4.blocks.bottlenecks.0.cv1.branch_3x3.bn.weight", "backbone.stage4.blocks.bottlenecks.0.cv1.branch_3x3.bn.bias", "backbone.stage4.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_mean", "backbone.stage4.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_var", "backbone.stage4.blocks.bottlenecks.0.cv1.branch_1x1.weight", "backbone.stage4.blocks.bottlenecks.0.cv1.branch_1x1.bias", "backbone.stage4.blocks.bottlenecks.0.cv2.branch_3x3.conv.weight", "backbone.stage4.blocks.bottlenecks.0.cv2.branch_3x3.bn.weight", "backbone.stage4.blocks.bottlenecks.0.cv2.branch_3x3.bn.bias", "backbone.stage4.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_mean", "backbone.stage4.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_var", "backbone.stage4.blocks.bottlenecks.0.cv2.branch_1x1.weight", "backbone.stage4.blocks.bottlenecks.0.cv2.branch_1x1.bias", "backbone.stage4.blocks.bottlenecks.1.cv1.branch_3x3.conv.weight", "backbone.stage4.blocks.bottlenecks.1.cv1.branch_3x3.bn.weight", "backbone.stage4.blocks.bottlenecks.1.cv1.branch_3x3.bn.bias", "backbone.stage4.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_mean", "backbone.stage4.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_var", "backbone.stage4.blocks.bottlenecks.1.cv1.branch_1x1.weight", "backbone.stage4.blocks.bottlenecks.1.cv1.branch_1x1.bias", "backbone.stage4.blocks.bottlenecks.1.cv2.branch_3x3.conv.weight", "backbone.stage4.blocks.bottlenecks.1.cv2.branch_3x3.bn.weight", "backbone.stage4.blocks.bottlenecks.1.cv2.branch_3x3.bn.bias", "backbone.stage4.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_mean", "backbone.stage4.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_var", "backbone.stage4.blocks.bottlenecks.1.cv2.branch_1x1.weight", "backbone.stage4.blocks.bottlenecks.1.cv2.branch_1x1.bias", "neck.neck1.blocks.bottlenecks.0.cv1.branch_3x3.conv.weight", "neck.neck1.blocks.bottlenecks.0.cv1.branch_3x3.bn.weight", "neck.neck1.blocks.bottlenecks.0.cv1.branch_3x3.bn.bias", "neck.neck1.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_mean", "neck.neck1.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_var", "neck.neck1.blocks.bottlenecks.0.cv1.branch_1x1.weight", "neck.neck1.blocks.bottlenecks.0.cv1.branch_1x1.bias", "neck.neck1.blocks.bottlenecks.0.cv2.branch_3x3.conv.weight", "neck.neck1.blocks.bottlenecks.0.cv2.branch_3x3.bn.weight", "neck.neck1.blocks.bottlenecks.0.cv2.branch_3x3.bn.bias", "neck.neck1.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_mean", "neck.neck1.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_var", "neck.neck1.blocks.bottlenecks.0.cv2.branch_1x1.weight", "neck.neck1.blocks.bottlenecks.0.cv2.branch_1x1.bias", "neck.neck1.blocks.bottlenecks.1.cv1.branch_3x3.conv.weight", "neck.neck1.blocks.bottlenecks.1.cv1.branch_3x3.bn.weight", "neck.neck1.blocks.bottlenecks.1.cv1.branch_3x3.bn.bias", "neck.neck1.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_mean", "neck.neck1.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_var", "neck.neck1.blocks.bottlenecks.1.cv1.branch_1x1.weight", "neck.neck1.blocks.bottlenecks.1.cv1.branch_1x1.bias", "neck.neck1.blocks.bottlenecks.1.cv2.branch_3x3.conv.weight", 
"neck.neck1.blocks.bottlenecks.1.cv2.branch_3x3.bn.weight", "neck.neck1.blocks.bottlenecks.1.cv2.branch_3x3.bn.bias", "neck.neck1.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_mean", "neck.neck1.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_var", "neck.neck1.blocks.bottlenecks.1.cv2.branch_1x1.weight", "neck.neck1.blocks.bottlenecks.1.cv2.branch_1x1.bias", "neck.neck2.blocks.bottlenecks.0.cv1.branch_3x3.conv.weight", "neck.neck2.blocks.bottlenecks.0.cv1.branch_3x3.bn.weight", "neck.neck2.blocks.bottlenecks.0.cv1.branch_3x3.bn.bias", "neck.neck2.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_mean", "neck.neck2.blocks.bottlenecks.0.cv1.branch_3x3.bn.running_var", "neck.neck2.blocks.bottlenecks.0.cv1.branch_1x1.weight", "neck.neck2.blocks.bottlenecks.0.cv1.branch_1x1.bias", "neck.neck2.blocks.bottlenecks.0.cv2.branch_3x3.conv.weight", "neck.neck2.blocks.bottlenecks.0.cv2.branch_3x3.bn.weight", "neck.neck2.blocks.bottlenecks.0.cv2.branch_3x3.bn.bias", "neck.neck2.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_mean", "neck.neck2.blocks.bottlenecks.0.cv2.branch_3x3.bn.running_var", "neck.neck2.blocks.bottlenecks.0.cv2.branch_1x1.weight", "neck.neck2.blocks.bottlenecks.0.cv2.branch_1x1.bias", "neck.neck2.blocks.bottlenecks.1.cv1.branch_3x3.conv.weight", "neck.neck2.blocks.bottlenecks.1.cv1.branch_3x3.bn.weight", "neck.neck2.blocks.bottlenecks.1.cv1.branch_3x3.bn.bias", "neck.neck2.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_mean", "neck.neck2.blocks.bottlenecks.1.cv1.branch_3x3.bn.running_var", "neck.neck2.blocks.bottlenecks.1.cv1.branch_1x1.weight", "neck.neck2.blocks.bottlenecks.1.cv1.branch_1x1.bias", "neck.neck2.blocks.bottlenecks.1.cv2.branch_3x3.conv.weight", "neck.neck2.blocks.bottlenecks.1.cv2.branch_3x3.bn.weight", "neck.neck2.blocks.bottlenecks.1.cv2.branch_3x3.bn.bias", "neck.neck2.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_mean", "neck.neck2.blocks.bottlenecks.1.cv2.branch_3x3.bn.running_var", "neck.neck2.blocks.bottlenecks.1.cv2.branch_1x1.weight", "neck.neck2.blocks.bottlenecks.1.cv2.branch_1x1.bias".
        Unexpected key(s) in state_dict: "backbone.stem.conv.rbr_reparam._input_quantizer._amax", "backbone.stem.conv.rbr_reparam._weight_quantizer._amax", "backbone.stage1.downsample.rbr_reparam._input_quantizer._amax", "backbone.stage1.downsample.rbr_reparam._weight_quantizer._amax", "backbone.stage1.blocks.conv1.conv._input_quantizer._amax", "backbone.stage1.blocks.conv1.conv._weight_quantizer._amax", "backbone.stage1.blocks.conv2.conv._input_quantizer._amax", "backbone.stage1.blocks.conv2.conv._weight_quantizer._amax", "backbone.stage1.blocks.conv3.conv._input_quantizer._amax", "backbone.stage1.blocks.conv3.conv._weight_quantizer._amax", "backbone.stage1.blocks.bottlenecks.0.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage1.blocks.bottlenecks.0.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage1.blocks.bottlenecks.0.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage1.blocks.bottlenecks.0.cv2.rbr_reparam._weight_quantizer._amax", "backbone.stage1.blocks.bottlenecks.1.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage1.blocks.bottlenecks.1.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage1.blocks.bottlenecks.1.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage1.blocks.bottlenecks.1.cv2.rbr_reparam._weight_quantizer._amax", "backbone.stage2.downsample.rbr_reparam._input_quantizer._amax", "backbone.stage2.downsample.rbr_reparam._weight_quantizer._amax", "backbone.stage2.blocks.conv1.conv._input_quantizer._amax", "backbone.stage2.blocks.conv1.conv._weight_quantizer._amax", "backbone.stage2.blocks.conv2.conv._input_quantizer._amax", "backbone.stage2.blocks.conv2.conv._weight_quantizer._amax", "backbone.stage2.blocks.conv3.conv._input_quantizer._amax", "backbone.stage2.blocks.conv3.conv._weight_quantizer._amax", "backbone.stage2.blocks.bottlenecks.0.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage2.blocks.bottlenecks.0.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage2.blocks.bottlenecks.0.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage2.blocks.bottlenecks.0.cv2.rbr_reparam._weight_quantizer._amax", "backbone.stage2.blocks.bottlenecks.1.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage2.blocks.bottlenecks.1.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage2.blocks.bottlenecks.1.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage2.blocks.bottlenecks.1.cv2.rbr_reparam._weight_quantizer._amax", "backbone.stage2.blocks.bottlenecks.2.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage2.blocks.bottlenecks.2.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage2.blocks.bottlenecks.2.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage2.blocks.bottlenecks.2.cv2.rbr_reparam._weight_quantizer._amax", "backbone.stage3.downsample.rbr_reparam._input_quantizer._amax", "backbone.stage3.downsample.rbr_reparam._weight_quantizer._amax", "backbone.stage3.blocks.conv1.conv._input_quantizer._amax", "backbone.stage3.blocks.conv1.conv._weight_quantizer._amax", "backbone.stage3.blocks.conv2.conv._input_quantizer._amax", "backbone.stage3.blocks.conv2.conv._weight_quantizer._amax", "backbone.stage3.blocks.conv3.conv._input_quantizer._amax", "backbone.stage3.blocks.conv3.conv._weight_quantizer._amax", "backbone.stage3.blocks.bottlenecks.0.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage3.blocks.bottlenecks.0.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage3.blocks.bottlenecks.0.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage3.blocks.bottlenecks.0.cv2.rbr_reparam._weight_quantizer._amax", 
"backbone.stage3.blocks.bottlenecks.1.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage3.blocks.bottlenecks.1.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage3.blocks.bottlenecks.1.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage3.blocks.bottlenecks.1.cv2.rbr_reparam._weight_quantizer._amax", "backbone.stage3.blocks.bottlenecks.2.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage3.blocks.bottlenecks.2.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage3.blocks.bottlenecks.2.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage3.blocks.bottlenecks.2.cv2.rbr_reparam._weight_quantizer._amax", "backbone.stage3.blocks.bottlenecks.3.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage3.blocks.bottlenecks.3.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage3.blocks.bottlenecks.3.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage3.blocks.bottlenecks.3.cv2.rbr_reparam._weight_quantizer._amax", "backbone.stage3.blocks.bottlenecks.4.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage3.blocks.bottlenecks.4.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage3.blocks.bottlenecks.4.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage3.blocks.bottlenecks.4.cv2.rbr_reparam._weight_quantizer._amax", "backbone.stage4.downsample.rbr_reparam._input_quantizer._amax", "backbone.stage4.downsample.rbr_reparam._weight_quantizer._amax", "backbone.stage4.blocks.conv1.conv._input_quantizer._amax", "backbone.stage4.blocks.conv1.conv._weight_quantizer._amax", "backbone.stage4.blocks.conv2.conv._input_quantizer._amax", "backbone.stage4.blocks.conv2.conv._weight_quantizer._amax", "backbone.stage4.blocks.conv3.conv._input_quantizer._amax", "backbone.stage4.blocks.conv3.conv._weight_quantizer._amax", "backbone.stage4.blocks.bottlenecks.0.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage4.blocks.bottlenecks.0.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage4.blocks.bottlenecks.0.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage4.blocks.bottlenecks.0.cv2.rbr_reparam._weight_quantizer._amax", "backbone.stage4.blocks.bottlenecks.1.cv1.rbr_reparam._input_quantizer._amax", "backbone.stage4.blocks.bottlenecks.1.cv1.rbr_reparam._weight_quantizer._amax", "backbone.stage4.blocks.bottlenecks.1.cv2.rbr_reparam._input_quantizer._amax", "backbone.stage4.blocks.bottlenecks.1.cv2.rbr_reparam._weight_quantizer._amax", "backbone.context_module.cv1.conv._input_quantizer._amax", "backbone.context_module.cv1.conv._weight_quantizer._amax", "backbone.context_module.cv2.conv._input_quantizer._amax", "backbone.context_module.cv2.conv._weight_quantizer._amax", "neck.neck1.reduce_skip1.conv._input_quantizer._amax", "neck.neck1.reduce_skip1.conv._weight_quantizer._amax", "neck.neck1.reduce_skip2.conv._input_quantizer._amax", "neck.neck1.reduce_skip2.conv._weight_quantizer._amax", "neck.neck1.conv.conv._input_quantizer._amax", "neck.neck1.conv.conv._weight_quantizer._amax", "neck.neck1.upsample._input_quantizer._amax", "neck.neck1.upsample._weight_quantizer._amax", "neck.neck1.downsample.conv._input_quantizer._amax", "neck.neck1.downsample.conv._weight_quantizer._amax", "neck.neck1.reduce_after_concat.conv._input_quantizer._amax", "neck.neck1.reduce_after_concat.conv._weight_quantizer._amax", "neck.neck1.blocks.conv1.conv._input_quantizer._amax", "neck.neck1.blocks.conv1.conv._weight_quantizer._amax", "neck.neck1.blocks.conv2.conv._input_quantizer._amax", "neck.neck1.blocks.conv2.conv._weight_quantizer._amax", 
"neck.neck1.blocks.conv3.conv._input_quantizer._amax", "neck.neck1.blocks.conv3.conv._weight_quantizer._amax", "neck.neck1.blocks.bottlenecks.0.cv1.rbr_reparam._input_quantizer._amax", "neck.neck1.blocks.bottlenecks.0.cv1.rbr_reparam._weight_quantizer._amax", "neck.neck1.blocks.bottlenecks.0.cv2.rbr_reparam._input_quantizer._amax", "neck.neck1.blocks.bottlenecks.0.cv2.rbr_reparam._weight_quantizer._amax", "neck.neck1.blocks.bottlenecks.1.cv1.rbr_reparam._input_quantizer._amax", "neck.neck1.blocks.bottlenecks.1.cv1.rbr_reparam._weight_quantizer._amax", "neck.neck1.blocks.bottlenecks.1.cv2.rbr_reparam._input_quantizer._amax", "neck.neck1.blocks.bottlenecks.1.cv2.rbr_reparam._weight_quantizer._amax", "neck.neck2.reduce_skip1.conv._input_quantizer._amax", "neck.neck2.reduce_skip1.conv._weight_quantizer._amax", "neck.neck2.reduce_skip2.conv._input_quantizer._amax", "neck.neck2.reduce_skip2.conv._weight_quantizer._amax", "neck.neck2.conv.conv._input_quantizer._amax", "neck.neck2.conv.conv._weight_quantizer._amax", "neck.neck2.upsample._input_quantizer._amax", "neck.neck2.upsample._weight_quantizer._amax", "neck.neck2.downsample.conv._input_quantizer._amax", "neck.neck2.downsample.conv._weight_quantizer._amax", "neck.neck2.reduce_after_concat.conv._input_quantizer._amax", "neck.neck2.reduce_after_concat.conv._weight_quantizer._amax", "neck.neck2.blocks.conv1.conv._input_quantizer._amax", "neck.neck2.blocks.conv1.conv._weight_quantizer._amax", "neck.neck2.blocks.conv2.conv._input_quantizer._amax", "neck.neck2.blocks.conv2.conv._weight_quantizer._amax", "neck.neck2.blocks.conv3.conv._input_quantizer._amax", "neck.neck2.blocks.conv3.conv._weight_quantizer._amax", "neck.neck2.blocks.bottlenecks.0.cv1.rbr_reparam._input_quantizer._amax", "neck.neck2.blocks.bottlenecks.0.cv1.rbr_reparam._weight_quantizer._amax", "neck.neck2.blocks.bottlenecks.0.cv2.rbr_reparam._input_quantizer._amax", "neck.neck2.blocks.bottlenecks.0.cv2.rbr_reparam._weight_quantizer._amax", "neck.neck2.blocks.bottlenecks.1.cv1.rbr_reparam._input_quantizer._amax", "neck.neck2.blocks.bottlenecks.1.cv1.rbr_reparam._weight_quantizer._amax", "neck.neck2.blocks.bottlenecks.1.cv2.rbr_reparam._input_quantizer._amax", "neck.neck2.blocks.bottlenecks.1.cv2.rbr_reparam._weight_quantizer._amax", "neck.neck3.conv.conv._input_quantizer._amax", "neck.neck3.conv.conv._weight_quantizer._amax", "neck.neck3.blocks.conv1.conv._input_quantizer._amax", "neck.neck3.blocks.conv1.conv._weight_quantizer._amax", "neck.neck3.blocks.conv2.conv._input_quantizer._amax", "neck.neck3.blocks.conv2.conv._weight_quantizer._amax", "neck.neck3.blocks.conv3.conv._input_quantizer._amax", "neck.neck3.blocks.conv3.conv._weight_quantizer._amax", "neck.neck3.blocks.bottlenecks.0.cv1.conv._input_quantizer._amax", "neck.neck3.blocks.bottlenecks.0.cv1.conv._weight_quantizer._amax", "neck.neck3.blocks.bottlenecks.0.cv2.conv._input_quantizer._amax", "neck.neck3.blocks.bottlenecks.0.cv2.conv._weight_quantizer._amax", "neck.neck3.blocks.bottlenecks.1.cv1.conv._input_quantizer._amax", "neck.neck3.blocks.bottlenecks.1.cv1.conv._weight_quantizer._amax", "neck.neck3.blocks.bottlenecks.1.cv2.conv._input_quantizer._amax", "neck.neck3.blocks.bottlenecks.1.cv2.conv._weight_quantizer._amax", "neck.neck4.conv.conv._input_quantizer._amax", "neck.neck4.conv.conv._weight_quantizer._amax", "neck.neck4.blocks.conv1.conv._input_quantizer._amax", "neck.neck4.blocks.conv1.conv._weight_quantizer._amax", "neck.neck4.blocks.conv2.conv._input_quantizer._amax", 
"neck.neck4.blocks.conv2.conv._weight_quantizer._amax", "neck.neck4.blocks.conv3.conv._input_quantizer._amax", "neck.neck4.blocks.conv3.conv._weight_quantizer._amax", "neck.neck4.blocks.bottlenecks.0.cv1.conv._input_quantizer._amax", "neck.neck4.blocks.bottlenecks.0.cv1.conv._weight_quantizer._amax", "neck.neck4.blocks.bottlenecks.0.cv2.conv._input_quantizer._amax", "neck.neck4.blocks.bottlenecks.0.cv2.conv._weight_quantizer._amax", "neck.neck4.blocks.bottlenecks.1.cv1.conv._input_quantizer._amax", "neck.neck4.blocks.bottlenecks.1.cv1.conv._weight_quantizer._amax", "neck.neck4.blocks.bottlenecks.1.cv2.conv._input_quantizer._amax", "neck.neck4.blocks.bottlenecks.1.cv2.conv._weight_quantizer._amax", "heads.head1.stem.seq.conv._input_quantizer._amax", "heads.head1.stem.seq.conv._weight_quantizer._amax", "heads.head1.cls_convs.0.seq.conv._input_quantizer._amax", "heads.head1.cls_convs.0.seq.conv._weight_quantizer._amax", "heads.head1.reg_convs.0.seq.conv._input_quantizer._amax", "heads.head1.reg_convs.0.seq.conv._weight_quantizer._amax", "heads.head1.cls_pred._input_quantizer._amax", "heads.head1.cls_pred._weight_quantizer._amax", "heads.head1.reg_pred._input_quantizer._amax", "heads.head1.reg_pred._weight_quantizer._amax", "heads.head2.stem.seq.conv._input_quantizer._amax", "heads.head2.stem.seq.conv._weight_quantizer._amax", "heads.head2.cls_convs.0.seq.conv._input_quantizer._amax", "heads.head2.cls_convs.0.seq.conv._weight_quantizer._amax", "heads.head2.reg_convs.0.seq.conv._input_quantizer._amax", "heads.head2.reg_convs.0.seq.conv._weight_quantizer._amax", "heads.head2.cls_pred._input_quantizer._amax", "heads.head2.cls_pred._weight_quantizer._amax", "heads.head2.reg_pred._input_quantizer._amax", "heads.head2.reg_pred._weight_quantizer._amax", "heads.head3.stem.seq.conv._input_quantizer._amax", "heads.head3.stem.seq.conv._weight_quantizer._amax", "heads.head3.cls_convs.0.seq.conv._input_quantizer._amax", "heads.head3.cls_convs.0.seq.conv._weight_quantizer._amax", "heads.head3.reg_convs.0.seq.conv._input_quantizer._amax", "heads.head3.reg_convs.0.seq.conv._weight_quantizer._amax", "heads.head3.cls_pred._input_quantizer._amax", "heads.head3.cls_pred._weight_quantizer._amax", "heads.head3.reg_pred._input_quantizer._amax", "heads.head3.reg_pred._weight_quantizer._amax".

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[2], line 1
----> 1 model = models.get('yolo_nas_s', checkpoint_path="checkpoints/coco2017_yolo_nas_s_qat/RUN_20240131_184123_827969/ckpt_best.pth", num_classes=80)

File ~/venvs/s-grads/lib/python3.10/site-packages/super_gradients/common/decorators/factory_decorator.py:36, in resolve_param.<locals>.inner.<locals>.wrapper(*args, **kwargs)
     34             new_value = factory.get(args[index])
     35             args = _assign_tuple(args, index, new_value)
---> 36 return func(*args, **kwargs)

File ~/venvs/s-grads/lib/python3.10/site-packages/super_gradients/training/models/model_factory.py:241, in get(model_name, arch_params, num_classes, strict_load, checkpoint_path, pretrained_weights, load_backbone, download_required_code, checkpoint_num_classes, num_input_channels)
    239     load_processing = "processing_params" in ckpt_entries
    240     load_ema_as_net = "ema_net" in ckpt_entries
--> 241     _ = load_checkpoint_to_model(
    242         ckpt_local_path=checkpoint_path,
    243         load_backbone=load_backbone,
    244         net=net,
    245         strict=strict_load,
    246         load_weights_only=True,
    247         load_ema_as_net=load_ema_as_net,
    248         load_processing_params=load_processing,
    249     )
    250 if checkpoint_num_classes != num_classes:
    251     net.replace_head(new_num_classes=num_classes)

File ~/venvs/s-grads/lib/python3.10/site-packages/super_gradients/common/decorators/factory_decorator.py:36, in resolve_param.<locals>.inner.<locals>.wrapper(*args, **kwargs)
     34             new_value = factory.get(args[index])
     35             args = _assign_tuple(args, index, new_value)
---> 36 return func(*args, **kwargs)

File ~/venvs/s-grads/lib/python3.10/site-packages/super_gradients/training/utils/checkpoint_utils.py:1523, in load_checkpoint_to_model(net, ckpt_local_path, load_backbone, strict, load_weights_only, load_ema_as_net, load_processing_params)
   1521     adaptive_load_state_dict(net.backbone, checkpoint, strict)
   1522 else:
-> 1523     adaptive_load_state_dict(net, checkpoint, strict)
   1525 message_suffix = " checkpoint." if not load_ema_as_net else " EMA checkpoint."
   1526 message_model = "model" if not load_backbone else "model's backbone"

File ~/venvs/s-grads/lib/python3.10/site-packages/super_gradients/training/utils/checkpoint_utils.py:102, in adaptive_load_state_dict(net, state_dict, strict, solver)
    100 except (RuntimeError, ValueError, KeyError) as ex:
    101     if strict == StrictLoad.NO_KEY_MATCHING:
--> 102         adapted_state_dict = adapt_state_dict_to_fit_model_layer_names(net.state_dict(), state_dict, solver=solver)
    103         net.load_state_dict(adapted_state_dict["net"], strict=True)
    104     elif strict == StrictLoad.KEY_MATCHING:

File ~/venvs/s-grads/lib/python3.10/site-packages/super_gradients/training/utils/checkpoint_utils.py:1458, in adapt_state_dict_to_fit_model_layer_names(model_state_dict, source_ckpt, exclude, solver)
   1455 if len(exclude):
   1456     model_state_dict = {k: v for k, v in model_state_dict.items() if not any(x in k for x in exclude)}
-> 1458 new_ckpt_dict = solver(model_state_dict, source_ckpt)
   1459 return {"net": new_ckpt_dict}

File ~/venvs/s-grads/lib/python3.10/site-packages/super_gradients/training/utils/checkpoint_utils.py:212, in DefaultCheckpointSolver.__call__(self, model_state_dict, checkpoint_state_dict)
    210 for (ckpt_key, ckpt_val), (model_key, model_val) in zip(checkpoint_state_dict.items(), model_state_dict.items()):
    211     if ckpt_val.shape != model_val.shape:
--> 212         raise ValueError(f"ckpt layer {ckpt_key} with shape {ckpt_val.shape} does not match {model_key}" f" with shape {model_val.shape} in the model")
    213     new_ckpt_dict[model_key] = ckpt_val
    214 return new_ckpt_dict

ValueError: ckpt layer backbone.stem.conv.post_bn.weight with shape torch.Size([48]) does not match backbone.stem.conv.branch_3x3.conv.weight with shape torch.Size([48, 3, 3, 3]) in the model

It would also help if this guide included the pre- and post-processing code needed when doing inference with TensorRT.

I believe I am missing something. There must be a way to load the quantized model, right?
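Judging by the key names above (branch_3x3/branch_1x1 expected by the fresh model vs. rbr_reparam and _amax present in the checkpoint), the checkpoint seems to describe a model whose re-parametrizable blocks were already fused and wrapped with quantizers. My guess, not verified, is that a freshly built model has to be put into that same state before the weights can load, roughly along these lines (using the SelectiveQuantizer utility from SG's quantization docs):

import torch
from super_gradients.training import models
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer

model = models.get("yolo_nas_s", num_classes=80)
# Fuse the re-parametrizable (QARepVGG) blocks so the module tree contains rbr_reparam
model.prep_model_for_conversion(input_size=[1, 3, 640, 640])
# Insert the _input_quantizer/_weight_quantizer modules the checkpoint refers to
SelectiveQuantizer().quantize_module(model)

ckpt = torch.load("<path-to-ckpt_best.pth>", map_location="cpu")
state_dict = ckpt.get("ema_net", ckpt.get("net", ckpt))
model.load_state_dict(state_dict, strict=False)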

Versions

SG version: 3.6.0
PyTorch version: 2.1.2+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060 Laptop GPU
Nvidia driver version: 546.17
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: 12th Gen Intel(R) Core(TM) i5-12450H
CPU family: 6
Model: 154
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 3
BogoMIPS: 4991.93
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 288 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 7.5 MiB (6 instances)
L3 cache: 12 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.23.0
[pip3] onnx==1.13.0
[pip3] onnx-graphsurgeon==0.3.27
[pip3] onnxruntime==1.13.1
[pip3] onnxsim==0.4.35
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.1.2+cu118
[pip3] torchmetrics==0.8.0
[pip3] torchvision==0.16.2+cu118
[pip3] triton==2.1.0
[conda] Could not collect

nsabir2011 commented 7 months ago

I decided to just save the model object with pickle, although I don't like it. I think the weight file should be loadable by SG.
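To be concrete, the workaround is just pickling the whole module object instead of its state dict, something like this (variable and file names are placeholders):

import torch
# torch.save pickles the entire nn.Module object, fused blocks and quantizers included
torch.save(qat_model, "yolo_nas_s_qat_model.pth")
# ...and later
qat_model = torch.load("yolo_nas_s_qat_model.pth", map_location="cpu")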

BloodAxe commented 7 months ago

No, saving the model as a pickle is definitely not the right way to save the model state.

As of now, qat_from_recipe doesn't give much control over how the final model is exported (with or without postprocessing). This feature is unfortunately lacking. It is a good place for improvement, but currently I'm unable to give a time estimate for when it may be introduced.