zhiqwang / yolort

yolort is a runtime stack for YOLOv5 on specialized accelerators such as TensorRT, LibTorch, ONNX Runtime, TVM and NCNN.
https://zhiqwang.com/yolort
GNU General Public License v3.0

Can't load custom trained model #466

Closed · errx closed this issue 1 year ago

errx commented 1 year ago

🐛 Describe the bug

I've trained a yolov5n model with the latest code from https://github.com/ultralytics/yolov5.

However, when I try to load it with yolort I get the following error:

from yolort.models import YOLOv5
model = YOLOv5.load_from_yolov5(PATH, score_thresh=MIN_CONFIDENCE)

File ~/wr/debug-yolort/venv/lib/python3.10/site-packages/yolort/models/yolov5.py:284, in YOLOv5.load_from_yolov5(cls, checkpoint_path, size, size_divisible, fixed_shape, fill_color, **kwargs)
    259 @classmethod
    260 def load_from_yolov5(
    261     cls,
   (...)
    268     **kwargs: Any,
    269 ):
    270     """
    271     Load custom checkpoints trained from YOLOv5.
    272
   (...)
    282         fill_color (int): fill value for padding. Default: 114
    283     """
--> 284     model = YOLO.load_from_yolov5(checkpoint_path, **kwargs)
    285     yolov5 = cls(
    286         model=model,
    287         size=size,
   (...)
    290         fill_color=fill_color,
    291     )
    292     return yolov5

File ~/wr/debug-yolort/venv/lib/python3.10/site-packages/yolort/models/yolo.py:204, in YOLO.load_from_yolov5(cls, checkpoint_path, score_thresh, nms_thresh, version, post_process)
    185 @classmethod
    186 def load_from_yolov5(
    187     cls,
   (...)
    192     post_process: Optional[nn.Module] = None,
    193 ):
    194     """
    195     Load model state from the checkpoint trained by YOLOv5.
    196
   (...)
    202             values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
    203     """
--> 204     model_info = load_from_ultralytics(checkpoint_path, version=version)
    205     backbone_name = f"darknet_{model_info['size']}_{version.replace('.', '_')}"
    206     depth_multiple = model_info["depth_multiple"]

File ~/wr/debug-yolort/venv/lib/python3.10/site-packages/yolort/models/_checkpoint.py:33, in load_from_ultralytics(checkpoint_path, version)
     27 if version not in ["r3.1", "r4.0", "r6.0"]:
     28     raise NotImplementedError(
     29         f"Currently does not support version: {version}. Feel free to file an issue "
     30         "labeled enhancement to us."
     31     )
---> 33 checkpoint_yolov5 = load_yolov5_model(checkpoint_path)
     34 num_classes = checkpoint_yolov5.yaml["nc"]
     35 strides = checkpoint_yolov5.stride

File ~/wr/debug-yolort/venv/lib/python3.10/site-packages/yolort/v5/helper.py:67, in load_yolov5_model(checkpoint_path, fuse)
     50 """
     51 Creates a specified YOLOv5 model.
     52
   (...)
     63     YOLOv5 pytorch model
     64 """
     66 with add_yolov5_context():
---> 67     ckpt = torch.load(attempt_download(checkpoint_path), map_location=torch.device("cpu"))
     68     if fuse:
     69         model = ckpt["ema" if ckpt.get("ema") else "model"].float().fuse().eval()

File ~/wr/debug-yolort/venv/lib/python3.10/site-packages/torch/serialization.py:789, in load(f, map_location, pickle_module, weights_only, **pickle_load_args)
    787             except RuntimeError as e:
    788                 raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
--> 789         return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    790 if weights_only:
    791     try:

File ~/wr/debug-yolort/venv/lib/python3.10/site-packages/torch/serialization.py:1131, in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
   1129 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1130 unpickler.persistent_load = persistent_load
-> 1131 result = unpickler.load()
   1133 torch._utils._validate_loaded_sparse_tensors()
   1135 return result

File ~/wr/debug-yolort/venv/lib/python3.10/site-packages/torch/serialization.py:1124, in _load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
   1122         pass
   1123 mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1124 return super().find_class(mod_name, name)

AttributeError: Can't get attribute 'DetectionModel' on <module 'models.yolo' from '/home/err/wr/debug-yolort/venv/lib/python3.10/site-packages/yolort/v5/models/yolo.py'>

I guess it's related to https://github.com/ultralytics/yolov5/issues/9151

But I'm not sure what I should do. Thank you.
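For context: torch.load unpickles the full model object, so every class name recorded in the checkpoint must be importable at load time. The failing lookup is roughly equivalent to the sketch below (module names taken from the traceback, simplified for illustration; the import only resolves while yolort's add_yolov5_context() has yolort/v5 on sys.path):

import importlib

# Newer ultralytics/yolov5 checkpoints pickle a reference to
# models.yolo.DetectionModel. Unpickling resolves that name roughly like this;
# inside yolort, "models.yolo" is the vendored copy at yolort/v5/models/yolo.py,
# which only defines Model and Detect, hence the AttributeError.
module = importlib.import_module("models.yolo")
cls = getattr(module, "DetectionModel")  # fails: attribute does not exist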

Versions

PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 12.2.0
Clang version: 14.0.6
CMake version: version 3.24.3
Libc version: glibc-2.36

Python version: 3.10.8 (main, Nov 1 2022, 14:18:21) [GCC 12.2.0] (64-bit runtime)
Python platform: Linux-6.0.8-arch1-1-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1080
Nvidia driver version: 520.56.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==1.13.0
[pip3] torchvision==0.14.0
[conda] Could not collect

errx commented 1 year ago
diff --git a/yolort/v5/models/yolo.py b/yolort/v5/models/yolo.py
index 38ae7a3..ae75205 100644
--- a/yolort/v5/models/yolo.py
+++ b/yolort/v5/models/yolo.py
@@ -38,7 +38,7 @@ from .experimental import CrossConv, MixConv2d
 if is_module_available("thop"):
     import thop  # for FLOPs computation

-__all__ = ["Model", "Detect"]
+__all__ = ["Model", "Detect", "DetectionModel"]

 LOGGER = logging.getLogger(__name__)

@@ -336,3 +336,5 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
             ch = []
         ch.append(c2)
     return nn.Sequential(*layers), sorted(save)
+
+DetectionModel = Model

This hack seems to work.
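If editing the installed package is not an option, an untested alternative sketch would be to register the same alias at runtime, before loading (module paths assumed from the traceback above; the checkpoint path and threshold are placeholders):

import sys
import yolort.v5.models.yolo as vendored_yolo
from yolort.models import YOLOv5

PATH = "best.pt"      # placeholder: path to the custom-trained checkpoint
MIN_CONFIDENCE = 0.3  # placeholder: score threshold

# Expose the vendored Model class under the name newer checkpoints expect,
# and pre-register the module under the name the unpickler will look up.
vendored_yolo.DetectionModel = vendored_yolo.Model
sys.modules.setdefault("models.yolo", vendored_yolo)

model = YOLOv5.load_from_yolov5(PATH, score_thresh=MIN_CONFIDENCE)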

zhiqwang commented 1 year ago

Thanks for reporting this bug to us @errx. It seems that your strategy is correct; we do not support YOLOv5's v6.2 release or master branch at this time. It would be great if you could open a Pull Request with this change.
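For anyone unsure which YOLOv5 release produced a given checkpoint, one rough way to check (without unpickling it) is to look at the class names recorded in the archive's data.pkl; a minimal sketch, using a hypothetical best.pt path:

import zipfile

# torch.save writes a zip archive whose data.pkl records the pickled class names:
# checkpoints from YOLOv5 v6.2+/master reference models.yolo.DetectionModel,
# while r6.0/v6.1 checkpoints reference models.yolo.Model.
with zipfile.ZipFile("best.pt") as zf:  # hypothetical checkpoint path
    pkl_name = next(n for n in zf.namelist() if n.endswith("data.pkl"))
    payload = zf.read(pkl_name)

print(b"DetectionModel" in payload)  # True -> saved by a newer, unsupported YOLOv5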