apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

Complex op such as `irfftn` doesn't support dynamic shapes #1957

Open chophilip21 opened 1 year ago

chophilip21 commented 1 year ago

🐞Describing the bug

Stack Trace

DEBUG:coremltools:Adding const op '256_end_0'
INFO:coremltools:Adding op '256_end_0' of type const
DEBUG:coremltools:Downcast const op 256_end_0 dataint64 as int32
DEBUG:coremltools:Downcast const op 256_end_0 dataint64 as int32
DEBUG:coremltools:Adding const op '256_end_mask_0'
INFO:coremltools:Adding op '256_end_mask_0' of type const
DEBUG:coremltools:Adding const op '256_squeeze_mask_0'
INFO:coremltools:Adding op '256_squeeze_mask_0' of type const
INFO:coremltools:Converting op ffted3.3 : complex
INFO:coremltools:Adding op 'complex_0' of type complex
Converting PyTorch Frontend ==> MIL Ops:   5%|███████ | 152/3297 [00:00<00:04, 650.29 ops/s]
Traceback (most recent call last):
  File "/Users/philip/SkyCoreML/main.py", line 51, in <module>
    main()
  File "/Users/philip/SkyCoreML/main.py", line 43, in main
    lama.convert(args)
  File "/Users/philip/SkyCoreML/src/skycoreml/conversion/lama.py", line 124, in convert
    coreml_model = ct.convert(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/_converters_entry.py", line 530, in convert
    mlmodel = mil_convert(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 63, in load
    return _perform_torch_convert(converter, debug)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 102, in _perform_torch_convert
    prog = converter.convert()
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 439, in convert
    convert_nodes(self.context, self.graph)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 92, in convert_nodes
    add_op(context, node)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 5691, in complex
    result = mb.complex(real_data=real_part, imag_data=imag_part)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 183, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/builder.py", line 182, in _add_op
    new_op.type_value_inference()
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/operation.py", line 253, in type_value_inference
    output_types = self.type_inference()
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py", line 162, in type_inference
    raise ValueError(
ValueError: The shape of real_data ((1, 192, is32, is33)) and imag_data ((1, 192, is34, is35)) must match to construct complex data.

To Reproduce

To reproduce, clone the above repository for CoreLama and run convert_lama.py. The only change needed is to give the image and mask inputs a flexible shape, for example:

    import coremltools as ct

    # jit_model is the traced/scripted LaMa model produced in convert_lama.py.
    image_shape = ct.EnumeratedShapes(
        shapes=[[1, 3, 256, 256], [1, 3, 512, 512], [1, 3, 1024, 1024]],
        default=[1, 3, 512, 512],
    )

    mask_shape = ct.EnumeratedShapes(
        shapes=[[1, 1, 256, 256], [1, 1, 512, 512], [1, 1, 1024, 1024]],
        default=[1, 1, 512, 512],
    )

    coreml_model = ct.convert(
        jit_model,
        convert_to="mlprogram",
        compute_precision=ct.precision.FLOAT32,
        compute_units=ct.ComputeUnit.CPU_AND_GPU,
        inputs=[
            ct.ImageType(name="image",
                         shape=image_shape,
                         scale=1/255.0),
            ct.ImageType(
                name="mask",
                shape=mask_shape,
                color_layout=ct.colorlayout.GRAYSCALE)
        ],
        outputs=[ct.ImageType(name="output")],
        skip_model_load=True
    )

Things I have tried

It doesn't matter whether the model is traced or scripted: conversion fails whenever a flexible shape is enabled. The expected behavior is that a proper .mlpackage is produced. I have tried multiple versions of PyTorch and coremltools without success. I also tried falling back to the legacy neuralnetwork backend and setting the shape range on the spec afterwards, as below, but the resulting model file seems to have some kind of memory leak.

    import os

    import coremltools as ct
    from coremltools.models.neural_network import flexible_shape_utils

    # coreml_model_file_name and save_dir come from the earlier neuralnetwork conversion.
    # Load the spec and look up the input names.
    spec = ct.utils.load_spec(coreml_model_file_name)
    image = spec.description.input[0].name
    mask = spec.description.input[1].name

    # Give the image input a flexible shape range.
    flexible_shape_utils.set_multiarray_ndshape_range(
        spec,
        feature_name=image,
        lower_bounds=[1, 3, 256, 256],
        upper_bounds=[1, 3, 1024, 1024],
    )

    # Give the mask input a flexible shape range.
    flexible_shape_utils.set_multiarray_ndshape_range(
        spec,
        feature_name=mask,
        lower_bounds=[1, 1, 256, 256],
        upper_bounds=[1, 1, 1024, 1024],
    )

    # Save the updated spec.
    coreml_model_file_name = "LaMa_updated.mlmodel"
    coreml_model_file_name = os.path.join(save_dir, coreml_model_file_name)
    ct.utils.save_spec(spec, coreml_model_file_name)

junpeiz commented 1 year ago

Hi @chophilip21, thank you for reporting this issue with detailed info!

You are right: the shape check in the complex op should also take dynamic (symbolic) shapes into account.

A quick fix on your side would be to change type_inference in "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py" (line 162):

    from coremltools.converters.mil.mil.types.symbolic import any_symbolic, is_symbolic

    ...

    def type_inference(self):
        # Don't compare the shapes directly if there is a symbolic shape.
        if any_symbolic(self.real_data.shape) or any_symbolic(self.imag_data.shape):
            # Check the non-symbolic dims to make sure they match.
            for dim, dim_size in enumerate(self.real_data.shape):
                if not is_symbolic(dim_size):
                    assert dim_size == self.imag_data.shape[dim]
        else:
            # Here is the original shape-checking logic.
            if self.real_data.shape != self.imag_data.shape:
                ...

Note that this is just a draft, but it should explain the logic. Thanks!
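The idea behind the relaxed check can be sketched outside coremltools. In the sketch below, symbolic dims such as is32 are modeled as plain strings; this is an assumption for illustration only, since coremltools uses its own symbolic objects and detects them with is_symbolic. The helper name check_complex_shapes is hypothetical. One design choice differs from the draft above: a dim is skipped when either side is symbolic, so a static dim paired with a symbolic one is not rejected.

```python
# Illustration only: symbolic dims are modeled as strings (e.g. "is32").
# coremltools represents them with its own symbolic objects instead.


def is_symbolic_dim(dim):
    """Hypothetical stand-in for coremltools' is_symbolic."""
    return isinstance(dim, str)


def check_complex_shapes(real_shape, imag_shape):
    """Hypothetical relaxed check: compare only dims known on both sides."""
    if len(real_shape) != len(imag_shape):
        raise ValueError(f"Rank mismatch: {real_shape} vs {imag_shape}")
    for i, (r, m) in enumerate(zip(real_shape, imag_shape)):
        if is_symbolic_dim(r) or is_symbolic_dim(m):
            # Cannot be decided at conversion time; leave it to runtime.
            continue
        if r != m:
            raise ValueError(f"Dim {i} mismatch: {r} vs {m}")
    # Like the draft, use the real part's shape as the output shape.
    return tuple(real_shape)


# The shapes from the original error now pass the check.
print(check_complex_shapes((1, 192, "is32", "is33"), (1, 192, "is34", "is35")))
```

A genuine mismatch on a static dim, such as (1, 192, 8) against (1, 64, 8), still raises ValueError.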

chophilip21 commented 1 year ago

@junpeiz

Thank you very much for your help!

    def type_inference(self):
        # Don't compare the shape directly if there is symbolic shape.
        if any_symbolic(self.real_data.shape) or any_symbolic(self.imag_data.shape):
            # Checking the non-symbolic dim to make sure they match.
            for dim, dim_size in enumerate(self.real_data.shape):
                if not is_symbolic(dim_size):
                    assert dim_size == self.imag_data.shape[dim]
        else:
            if self.real_data.shape != self.imag_data.shape:
                raise ValueError(
                    f"The shape of real_data ({self.real_data.shape}) and imag_data "
                    f"({self.imag_data.shape}) must match to construct complex data."
                )
        return types.tensor(
            infer_complex_dtype(self.real_data.dtype, self.imag_data.dtype),
            self.real_data.shape,
        )

With this change I can get past the check, but unfortunately the conversion does not get much further: it fails with another complex-data error.

DEBUG:coremltools:Adding const op 'gather_6_axis_0'
INFO:coremltools:Adding op 'gather_6_axis_0' of type const
INFO:coremltools:Converting op 259 : listconstruct
INFO:coremltools:Converting op 260 : listconstruct
INFO:coremltools:Adding op '260' of type const
INFO:coremltools:Converting op output.3 : fft_irfftn
INFO:coremltools:Adding op 'complex_irfftn_0' of type complex_irfftn
Converting PyTorch Frontend ==> MIL Ops:   5%|█████▊ | 156/3294 [00:00<00:01, 1915.99 ops/s]
Traceback (most recent call last):
  File "/Users/philip/SkyCoreML/main.py", line 40, in <module>
    main()
  File "/Users/philip/SkyCoreML/main.py", line 35, in main
    lama.convert(args)
  File "/Users/philip/SkyCoreML/src/skycoreml/conversion/lama.py", line 136, in convert
    coreml_model = ct.convert(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/_converters_entry.py", line 551, in convert
    mlmodel = mil_convert(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 75, in load
    return _perform_torch_convert(converter, debug)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 114, in _perform_torch_convert
    prog = converter.convert()
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 481, in convert
    convert_nodes(self.context, self.graph)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 93, in convert_nodes
    add_op(context, node)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 5893, in fft_irfftn
    irfftn_res = mb.complex_irfftn(data=input_data, shapes=shapes, dims=dims, norm=norm)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 183, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/builder.py", line 164, in _add_op
    kwargs.update(cls._create_vars(
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/builder.py", line 147, in _create_vars
    var = cls._add_const(val, new_var_name, before_op)
  File "/Users/philip/SkyCoreML/env/lib/python3.9/site-packages/coremltools/converters/mil/mil/builder.py", line 76, in _add_const
    raise ValueError("Cannot add const {}".format(val))
ValueError: Cannot add const [<coremltools.converters.mil.mil.var.Var object at 0x16b983700>, <coremltools.converters.mil.mil.var.Var object at 0x16b9838e0>]

The traceback points at:

@register_torch_op
def fft_irfftn(context, node):
    """Lowers torch.fft.irfftn by the dialect op `complex_irfftn` from complex_dialect_ops.py."""
    input_data, shapes, dims, norm = _get_inputs(context, node, expected=[4])
    irfftn_res = mb.complex_irfftn(data=input_data, shapes=shapes, dims=dims, norm=norm)
    context.add(irfftn_res, node.name)

Is this happening because we are only bypassing the check above, or could there be another reason?

junpeiz commented 1 year ago

@chophilip21 You are right: the issue is that the shapes input to mb.complex_irfftn is dynamic. However, the op definition in coremltools/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py requires shapes to be a const, which is why the error says "Cannot add const".
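Why a list of variables cannot become a const can be sketched in plain Python. This is not the coremltools builder code; RuntimeVar and as_const_shape are hypothetical names that only mimic the behavior seen in the traceback, where the builder was handed a list of two Var objects and raised "Cannot add const [...]".

```python
class RuntimeVar:
    """Hypothetical stand-in for a value that is only known when the model runs."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"RuntimeVar({self.name!r})"


def as_const_shape(shapes):
    """Hypothetical helper: turn `shapes` into a static tuple, or fail.

    A const must be fully known at conversion time, so any runtime
    variable inside the list makes the conversion impossible.
    """
    if shapes is None:
        return None
    if any(isinstance(s, RuntimeVar) for s in shapes):
        raise ValueError("Cannot add const {}".format(list(shapes)))
    return tuple(int(s) for s in shapes)


print(as_const_shape([256, 256]))  # static sizes are fine
try:
    as_const_shape([RuntimeVar("h"), RuntimeVar("w")])  # dynamic sizes
except ValueError as err:
    print(err)
```

Supporting dynamic shapes therefore means letting complex_irfftn accept a non-const shapes input, rather than relaxing a check.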

Here is a minimal example that reproduces the issue (so you can try it on your end without debugging the large LaMa model). It can be placed in coremltools/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py:

    @pytest.mark.parametrize(
        "compute_unit, backend",
        itertools.product(compute_units, backends),
    )
    def test_fftn_dynamic_shape(
        self, compute_unit: ct.ComputeUnit, backend
    ):
        class FftnModel(torch.nn.Module):
            def forward(self, x, y):
                x = torch.complex(x, x)
                return torch.fft.irfftn(x, s=y.shape, dim=None, norm=None)

        input_data = [torch.rand(2, 3, 4), torch.rand(1, 4)]
        input_type = [
            ct.TensorType(shape=(2, 3, RangeDim(1, 10))),
            ct.TensorType(shape=(RangeDim(1, 10), RangeDim(1, 10))),
        ]
        TorchBaseTest.run_compare_torch(
            input_data, FftnModel(), backend=backend, compute_unit=compute_unit, input_as_shape=False, converter_input_type=input_type,
        )
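In the test above, the model passes s=y.shape, so once y has a RangeDim input type the output size of irfftn is only known at runtime. The shape rule involved can be sketched from the torch.fft.irfftn documentation: dim defaults to all dims (or the last len(s) dims when s is given), and without s the last transformed dim becomes 2 * (n - 1). The helper irfftn_output_shape is hypothetical; it normalizes negative dims and does not handle -1 entries in s.

```python
def irfftn_output_shape(input_shape, s=None, dims=None):
    """Hypothetical helper: output shape of torch.fft.irfftn (documented defaults)."""
    rank = len(input_shape)
    if dims is None:
        # All dims by default, or the last len(s) dims if s is given.
        dims = list(range(rank)) if s is None else list(range(rank - len(s), rank))
    dims = [d % rank for d in dims]
    if s is None:
        # Keep every transformed dim except the last, which becomes 2 * (n - 1).
        s = [input_shape[d] for d in dims[:-1]] + [2 * (input_shape[dims[-1]] - 1)]
    if len(s) != len(dims):
        raise ValueError("s and dims must have the same length")
    out = list(input_shape)
    for d, size in zip(dims, s):
        out[d] = size
    return tuple(out)


# The test's complex input of shape (2, 3, 4) with s = y.shape = (1, 4):
print(irfftn_output_shape((2, 3, 4), s=(1, 4)))  # (2, 1, 4)
# Without s, the last dim is inferred as 2 * (4 - 1) = 6:
print(irfftn_output_shape((2, 3, 4)))  # (2, 3, 6)
```

Because the output size depends directly on s, a converter that infers static output types needs s at conversion time unless the op learns to propagate symbolic sizes.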

We will use this thread as a feature request for "Supporting dynamic shapes in complex irfftn op". Thank you for reporting this issue!

chophilip21 commented 1 year ago

Thank you, I look forward to hearing updates on this!

StevenSK-king commented 1 year ago

Hello everyone, what is the status of this request? Does Core ML support dynamic shapes here yet?

StevenSK-king commented 11 months ago

Any update on this issue?