lyuwenyu / RT-DETR

[CVPR 2024] Official RT-DETR (RTDETR paddle pytorch), Real-Time DEtection TRansformer, DETRs Beat YOLOs on Real-time Object Detection. 🔥 🔥 🔥
Apache License 2.0

Converting to coreml #341

Open ramonhollands opened 1 week ago

ramonhollands commented 1 week ago

When exporting to coreml, I encounter the following error:

ValueError: In op, of type linear, named out_w, the named input weight must have the same data type as the named input x. However, weight has dtype fp32 whereas x has dtype int32.

Any clues where to look?

Complete trace:

    Converting PyTorch Frontend ==> MIL Ops: 28%|█████ | 456/1610 [00:00<00:00, 4297.43 ops/s]
    Traceback (most recent call last):
      File "/RT-DETR/rtdetr_pytorch/tools/export_onnx.py", line 237, in <module>
        main(args)
      File "/RT-DETR/rtdetr_pytorch/tools/export_onnx.py", line 76, in main
        mlmodel = ct.convert(
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/_converters_entry.py", line 581, in convert
        mlmodel = mil_convert(
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
        return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
        proto, mil_program = mil_convert_to_proto(
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 288, in mil_convert_to_proto
        prog = frontend_converter(model, **kwargs)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
        return load(*args, **kwargs)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 82, in load
        return _perform_torch_convert(converter, debug)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 116, in _perform_torch_convert
        prog = converter.convert()
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 581, in convert
        convert_nodes(self.context, self.graph)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 86, in convert_nodes
        raise e # re-raise exception
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 81, in convert_nodes
        convert_single_node(context, node)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 134, in convert_single_node
        add_op(context, node)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 879, in matmul
        res = mb.linear(x=linear_x, weight=transposed_weight, name=node.name)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 182, in add_op
        return cls._add_op(op_cls_to_add, **kwargs)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/mil/builder.py", line 182, in _add_op
        new_op = op_cls(**kwargs)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/mil/operation.py", line 191, in __init__
        self._validate_and_set_inputs(input_kv)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/mil/operation.py", line 504, in _validate_and_set_inputs
        self.input_spec.validate_inputs(self.name, self.op_type, input_kvs)
      File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/mil/input_type.py", line 137, in validate_inputs
        raise ValueError(msg)
    ValueError: In op, of type linear, named out_w, the named input weight must have the same data type as the named input x. However, weight has dtype fp32 whereas x has dtype int32.

Code used in export_onnx:

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.model = cfg.model.deploy()
            self.postprocessor = cfg.postprocessor.deploy()

        def forward(self, images):
            outputs = self.model(images)
            orig_target_sizes = torch.tensor([image_width, image_height], dtype=torch.float32)
            orig_target_sizes = orig_target_sizes.to(images.device)
            return self.postprocessor(outputs, orig_target_sizes)

....

    try:
        import coremltools
    except ImportError:
        print("coremltools is not installed. Installing now...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "coremltools"])
        import coremltools
        print("coremltools has been installed successfully.")

    model = Model()
    model.eval()

    # One-element tuple of example inputs for torch.jit.trace
    example_input = (torch.rand(1, 3, image_height, image_width, dtype=torch.float32),)
    traced_model = torch.jit.trace(model, example_input)

    import coremltools as ct
    example_images = torch.rand(1, 3, image_height, image_width, dtype=torch.float32)  # Example input size for images
    input_images = ct.ImageType(name="images", shape=example_images.shape)

    # Convert the model
    mlmodel = ct.convert(
        traced_model,
        inputs=[input_images]
    )

    # Save the Core ML model
    mlmodel.save("best.mlmodel")

lyuwenyu commented 1 week ago

    File "/opt/conda/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 879, in matmul
        res = mb.linear(x=linear_x, weight=transposed_weight, name=node.name)

I'm not sure exactly which Linear is causing this. Could you try bisecting to narrow it down?
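
The bisection suggested above could look roughly like the sketch below. It is not code from the repository: `first_failing_stage`, the `converts_ok` callback, and the stage names are all hypothetical. In practice `converts_ok(n)` would have to wrap `torch.jit.trace` and `ct.convert` for a model truncated to its first `n` stages inside a `try`/`except`; how to truncate the model is left open here. The search assumes that once a prefix fails to convert, every longer prefix fails too, which should hold as long as the offending op stays in the graph.

```python
from typing import Callable, Sequence


def first_failing_stage(stages: Sequence[str], converts_ok: Callable[[int], bool]) -> str:
    """Binary-search for the first stage whose inclusion breaks conversion.

    converts_ok(n) is a hypothetical callback that reports whether a model
    truncated to its first n stages converts. The full model is assumed to fail.
    """
    if not stages:
        raise ValueError("need at least one stage")
    # Invariant: the prefix of length lo converts, the prefix of length hi fails.
    lo, hi = 0, len(stages)
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if converts_ok(mid):
            lo = mid
        else:
            hi = mid
    return stages[hi - 1]


# Simulated run with made-up stage names; the failure point is invented for the demo.
stages = ["backbone", "encoder", "decoder", "postprocessor"]
print(first_failing_stage(stages, lambda n: n < 2))  # prints "encoder"
```

Each call to `converts_ok` costs one trace and one conversion attempt, so the number of attempts grows with the logarithm of the number of stages rather than linearly. The same function can be reused on the submodules of whichever stage it reports.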