apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

Feature request for the implementation of variance, standard deviation, and other statistical MIL Ops #1203

Open RahulBhalley opened 3 years ago

RahulBhalley commented 3 years ago

🌱 Describe your Feature Request

Use cases

Describe alternatives you've considered

I have considered implementing my own function for calculating variance. My implementation is available in issue #1202's discussion comments.

However, not every developer will be able to implement these (or other) tensor operations correctly. Moreover, these are common statistical operations that users would expect to be available out of the box, since widely used frameworks like PyTorch and TensorFlow provide them.

Additional context

N/A

RahulBhalley commented 3 years ago

Hi @TobyRoseman, is there an update on this? It's been more than two months since I opened this issue. I need these operations to build my Core ML model. Do you have a timeline?

RahulBhalley commented 2 years ago

One workaround I'm considering is creating a custom operator, but it (or any custom operator) doesn't get registered in Core ML for PyTorch. Why? @TobyRoseman, can you please help me with this? It's really frustrating.

kory commented 2 years ago

Hi @RahulBhalley, I've written a composite operator for this operation, computing population variance (unbiased=False for the torch op). Let me know if this helps.

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op

@register_torch_op()
def var(context, node):
    # Inputs of aten::var: x, dim, unbiased, keepdim
    inputs = _get_inputs(context, node, expected=4)
    x = inputs[0]
    axes = inputs[1].val

    # Assert we have the biased divisor (N), i.e. unbiased=False.
    # An unbiased divisor (N - 1) would be much more complex,
    # since we can't use reduce_mean. We would therefore need
    # to compute the divisor manually for each axis.
    assert not inputs[2].val

    keepdim = inputs[3].val

    x_mean = mb.reduce_mean(x=x, axes=axes, keep_dims=keepdim)
    x_sub_mean = mb.sub(x=x, y=x_mean)
    x_sub_mean_square = mb.square(x=x_sub_mean)
    x_var = mb.reduce_mean(x=x_sub_mean_square, axes=axes, keep_dims=keepdim)

    context.add(x_var, torch_name=node.name)

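The four builder calls above implement the two-pass definition of population variance, with N the number of elements covered by the reduced axes: the first reduce_mean computes the mean, sub and square give the squared deviations, and the second reduce_mean averages them.

```latex
\bar{x} = \frac{1}{N}\sum_{i=1}^{N} x_i,
\qquad
\operatorname{Var}(x) = \frac{1}{N}\sum_{i=1}^{N}\left(x_i - \bar{x}\right)^2
```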
RahulBhalley commented 2 years ago

@kory thanks for the implementation, but it doesn't work. For an input tensor of shape torch.Size([1, 512, 64, 64]), I get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/var/folders/v8/72133fq53mngvq334v2xn0w00000gn/T/ipykernel_5317/1319924930.py in <module>
----> 1 mlmodel = ct.convert(
      2     traced_model,
      3     inputs=[ct.TensorType(name="inputA", shape=image_shape),
      4             ct.TensorType(name="inputB", shape=image_shape)]
      5 )

/usr/local/lib/python3.9/site-packages/coremltools/converters/_converters_entry.py in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, useCPUOnly, package_dir, debug)
    324         raise ValueError("Invalid value of the argument 'compute_precision'")
    325 
--> 326     mlmodel = mil_convert(
    327         model,
    328         convert_from=exact_source,

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/converter.py in mil_convert(model, convert_from, convert_to, compute_units, **kwargs)
    180         See `coremltools.converters.convert`
    181     """
--> 182     return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
    183 
    184 

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/converter.py in _mil_convert(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)
    207         _os.chmod(weights_dir, _stat.S_IRWXU | _stat.S_IRWXG | _stat.S_IRWXO)
    208 
--> 209     proto, mil_program = mil_convert_to_proto(
    210                             model,
    211                             convert_from,

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/converter.py in mil_convert_to_proto(model, convert_from, convert_to, converter_registry, **kwargs)
    298     frontend_converter = frontend_converter_type()
    299 
--> 300     prog = frontend_converter(model, **kwargs)
    301 
    302     if convert_to.lower() != "neuralnetwork":

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/converter.py in __call__(self, *args, **kwargs)
    102         from .frontend.torch import load
    103 
--> 104         return load(*args, **kwargs)
    105 
    106 

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py in load(model_spec, debug, **kwargs)
     48     cut_at_symbols = kwargs.get("cut_at_symbols", None)
     49     converter = TorchConverter(torchscript, inputs, outputs, cut_at_symbols)
---> 50     return _perform_torch_convert(converter, debug)
     51 
     52 

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py in _perform_torch_convert(converter, debug)
     85 def _perform_torch_convert(converter, debug):
     86     try:
---> 87         prog = converter.convert()
     88     except RuntimeError as e:
     89         if debug and "convert function" in str(e):

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/converter.py in convert(self)
    237 
    238             # Add the rest of the operations
--> 239             convert_nodes(self.context, self.graph)
    240 
    241             graph_outputs = [self.context[name] for name in self.graph.outputs]

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py in convert_nodes(context, graph)
     74                 "PyTorch convert function for op '{}' not implemented.".format(node.kind)
     75             )
---> 76         add_op(context, node)
     77 
     78         # We've generated all the outputs the graph needs, terminate conversion.

/var/folders/v8/72133fq53mngvq334v2xn0w00000gn/T/ipykernel_5317/1933354594.py in var(context, node)
     14 
     15     x_mean = mb.reduce_mean(x = x, axes = axes, keep_dims=keepdim)
---> 16     x_sub_mean = mb.sub(x = x, y = x_mean)
     17     x_sub_mean_square = mb.square(x = x_sub_mean)
     18     x_var = mb.reduce_mean(x = x_sub_mean_square, axes = axes, keep_dims=keepdim)

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/mil/ops/registry.py in add_op(cls, **kwargs)
     61             @classmethod
     62             def add_op(cls, **kwargs):
---> 63                 return cls._add_op(op_cls, **kwargs)
     64 
     65             setattr(Builder, op_type, add_op)

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/mil/builder.py in _add_op(cls, op_cls, **kwargs)
    189         curr_block()._insert_op_before(new_op, before_op=before_op)
    190         new_op.build_nested_blocks()
--> 191         new_op.type_value_inference()
    192         if len(new_op.outputs) == 1:
    193             return new_op.outputs[0]

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/mil/operation.py in type_value_inference(self, overwrite_output)
    238         existing _output_vars
    239         """
--> 240         output_types = self.type_inference()
    241         if not isinstance(output_types, tuple):
    242             output_types = (output_types,)

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/mil/ops/defs/elementwise_binary.py in type_inference(self)
     49         shapea = list(typea.get_shape())
     50         shapeb = list(typeb.get_shape())
---> 51         ret_shape = broadcast_shapes(shapea, shapeb)
     52         return types.tensor(primitive_type, ret_shape)
     53 

/usr/local/lib/python3.9/site-packages/coremltools/converters/mil/mil/ops/defs/_utils.py in broadcast_shapes(shape_x, shape_y)
     37         elif not y_unknown and shape_y[i] > 1:
     38             if not x_unknown and shape_x[i] != shape_y[i]:
---> 39                 raise ValueError(
     40                     "Incompatible dim {} in shapes {} vs. {}".format(
     41                         i, shape_x, shape_y

ValueError: Incompatible dim 2 in shapes (1, 512, 4096) vs. (1, 1, 512)
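The error most likely comes from passing keep_dims=False to the first reduce_mean: reducing (1, 512, 4096) over axis 2 yields (1, 512), and right-aligned broadcasting then compares 4096 against 512 (the traceback reports the padded shape (1, 1, 512)). The NumPy sketch below reproduces this, assuming MIL broadcasting follows the same right-aligned rules as NumPy, and shows the usual remedy: always keep the reduced axes for the intermediate mean, and apply keepdim only to the final reduction. NumPy stands in for the MIL builder here, and two_pass_var is a hypothetical name.

```python
import numpy as np

# Shape from the traceback: (1, 512, 4096), reduced over axis 2.
x = np.random.default_rng(0).standard_normal((1, 512, 4096))
axes = (2,)

# keepdims=False drops axis 2, leaving (1, 512). Broadcasting right-aligns
# (1, 512) against (1, 512, 4096), so 512 is compared with 4096 and fails.
mean_dropped = x.mean(axis=axes, keepdims=False)
try:
    x - mean_dropped
    broadcast_failed = False
except ValueError:
    broadcast_failed = True


def two_pass_var(x, axes, keepdim):
    """Population variance; the intermediate mean always keeps reduced axes."""
    mean = x.mean(axis=axes, keepdims=True)  # (1, 512, 1): broadcasts cleanly
    return np.square(x - mean).mean(axis=axes, keepdims=keepdim)


var = two_pass_var(x, axes, keepdim=False)
print(broadcast_failed, var.shape)
```

In the composite op, the analogous change would be keep_dims=True on the first mb.reduce_mean and keep_dims=keepdim on the second, so the output shape still matches the traced torch.var call.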
akshaydevelops commented 5 months ago

@RahulBhalley have you found a solution for this? I'm stuck on the same problem while trying to convert a wespeaker model like this.

I tried this possible fix, but no luck.

RahulBhalley commented 5 months ago

@akshaydevelops try this.

Disclaimer: I haven't tried it, but it should work.

@register_torch_op()
def var(context, node):
    inputs = _get_inputs(context, node, expected=4)
    x = inputs[0]
    axes = inputs[1].val

    # Ensure axes is a list of integers, e.g., [2, 3] for the last two dimensions (Change #3)
    assert isinstance(axes, list) and all(isinstance(axis, int) for axis in axes)

    # Assert we can have biased divisor (N). (Change #1)
    assert(inputs[2].val == False)

    keepdim = True  # Set keepdim to True for broadcasting (Change #2)

    x_mean = mb.reduce_mean(x=x, axes=axes, keep_dims=keepdim)
    x_sub_mean = mb.sub(x=x, y=x_mean)  # Broadcasting should work here (Change #4)
    x_sub_mean_square = mb.square(x=x_sub_mean)
    x_var = mb.reduce_mean(x=x_sub_mean_square, axes=axes, keep_dims=keepdim)

    context.add(x_var, torch_name=node.name)

By the way, if you'd rather not use a composite operator, I think my from-scratch PyTorch variance function (https://github.com/apple/coremltools/issues/1202#issuecomment-854640610) can be traced and converted to Core ML. Just replace torch.var() in your code with that var() function, and tracing plus conversion should work.

akshaydevelops commented 4 months ago

@RahulBhalley I tried the above code here:

wespeaker/bin/export_coreml.py

I get the following error:

(wespeaker) akshayreddy@akshayreddy bin % python3 export_coreml.py --config /Users/akshayreddy/repos/source/ML/model/config.yaml --checkpoint /Users/akshayreddy/repos/source/ML/model/wespeaker.pt --output_file /Users/akshayreddy/repos/source/ML/model/wespeaker_ml.mlmodel

model Loaded successfully
WARNING:root:unexpected tensor: projection.weight
/opt/anaconda3/envs/wespeaker/lib/python3.9/site-packages/wespeaker/models/resnet.py:189: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  return torch.tensor(0.0), embed_a
WARNING:coremltools:Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops:  99%|▉| 1692/1710 [00:00<00:00, 5553.04
Traceback (most recent call last):
  File "/Users/akshayreddy/repos/source/ML/wespeaker/wespeaker/bin/export_coreml.py", line 90, in <module>
    main()
  File "/Users/akshayreddy/repos/source/ML/wespeaker/wespeaker/bin/export_coreml.py", line 80, in main
    mlmodel = ct.convert(
  File "/opt/anaconda3/envs/wespeaker/lib/python3.9/site-packages/coremltools/converters/_converters_entry.py", line 574, in convert
    mlmodel = mil_convert(
  File "/opt/anaconda3/envs/wespeaker/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/opt/anaconda3/envs/wespeaker/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/opt/anaconda3/envs/wespeaker/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/opt/anaconda3/envs/wespeaker/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/opt/anaconda3/envs/wespeaker/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 80, in load
    return _perform_torch_convert(converter, debug)
  File "/opt/anaconda3/envs/wespeaker/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 99, in _perform_torch_convert
    prog = converter.convert()
  File "/opt/anaconda3/envs/wespeaker/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 519, in convert
    convert_nodes(self.context, self.graph)
  File "/opt/anaconda3/envs/wespeaker/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
    add_op(context, node)
  File "/Users/akshayreddy/repos/source/ML/wespeaker/wespeaker/bin/export_coreml.py", line 38, in var
    assert isinstance(axes, list) and all(
AssertionError
RahulBhalley commented 4 months ago

Sorry, I don't understand this traceback. The Apple team may be able to help here.

akshaydevelops commented 4 months ago

This is the assertion from the function above:

@register_torch_op()
def var(context, node):
    ...
    assert isinstance(axes, list) and all(isinstance(axis, int) for axis in axes)

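A plausible cause of this AssertionError (an assumption, not confirmed in this thread): for a constant dim argument, inputs[1].val typically holds a NumPy integer array rather than a Python list, so isinstance(axes, list) is False even when the axes are valid. The sketch below uses a hypothetical helper, normalize_axes, that accepts lists, tuples, scalars, or integer arrays and returns sorted non-negative Python ints instead of asserting on the container type.

```python
import numpy as np


def normalize_axes(axes, rank):
    """Hypothetical helper: turn an axes value (list, tuple, int, or NumPy
    array) into a sorted list of non-negative Python ints for a given rank."""
    arr = np.atleast_1d(np.asarray(axes))
    if not np.issubdtype(arr.dtype, np.integer):
        raise TypeError(f"axes must be integers, got dtype {arr.dtype}")
    result = []
    for a in arr.tolist():  # .tolist() yields plain Python ints
        if not -rank <= a < rank:
            raise ValueError(f"axis {a} out of range for rank {rank}")
        result.append(a % rank)  # map negative axes to their positive form
    return sorted(set(result))


axes_val = np.array([2, 3], dtype=np.int32)  # what .val may look like (assumed)
print(isinstance(axes_val, list))            # the original assertion's check
print(normalize_axes(axes_val, rank=4))
```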
akshaydevelops commented 4 months ago

@RahulBhalley This code works for me.


@register_torch_op()
def var(context, node):
    inputs = _get_inputs(context, node, expected=4)

    x = inputs[0]
    axes = inputs[1].val
    keepdim = True

    x_mean = mb.reduce_mean(x=x, axes=axes, keep_dims=keepdim)
    x_sub_mean = mb.sub(x=x, y=x_mean)
    x_sub_mean_square = mb.square(x=x_sub_mean)
    x_var = mb.reduce_mean(x=x_sub_mean_square, axes=axes, keep_dims=keepdim)

    context.add(x_var, torch_name=node.name)
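This working version no longer checks inputs[2], so it always computes the population (biased) variance, while torch.var defaults to unbiased=True and would silently get the N divisor. It also forces keepdim=True, so the output keeps size-1 reduced axes even when the traced call used keepdim=False. When the input shape is static, the unbiased result is just the population variance scaled by N/(N - 1), with N the number of elements in the reduced axes. The NumPy sketch below shows that rescaling under this static-shape assumption; var_with_correction is a hypothetical name, and multiplying by a constant (for example with mb.mul) would presumably be the corresponding MIL step, which is not tested here.

```python
from math import prod

import numpy as np


def var_with_correction(x, axes, unbiased, keepdim):
    """Two-pass variance over `axes`; rescales by N / (N - 1) if unbiased."""
    axes = tuple(axes)
    mean = x.mean(axis=axes, keepdims=True)  # keep axes so the sub broadcasts
    pop_var = np.square(x - mean).mean(axis=axes, keepdims=keepdim)
    if not unbiased:
        return pop_var
    n = prod(x.shape[a] for a in axes)  # elements per reduction (static shape)
    if n < 2:
        raise ValueError("unbiased variance needs at least 2 elements per reduction")
    return pop_var * (n / (n - 1))


x = np.random.default_rng(1).standard_normal((2, 3, 4, 5))
out = var_with_correction(x, [2, 3], unbiased=True, keepdim=False)
print(out.shape, np.allclose(out, x.var(axis=(2, 3), ddof=1)))
```

Raising for N < 2 is a design choice for the sketch; PyTorch itself returns nan or inf in that case rather than raising.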
RahulBhalley commented 4 months ago

So, only the assertions were wrong. 🤔

akshaydevelops commented 4 months ago

Yeah, looks like it.

TachibanaYoshino commented 3 weeks ago

This is indeed a problem, and the operator has gone unimplemented for a long time. I ran into it too while reproducing the TensorFlow code of AnimeGANv3 in PyTorch. Following the definition of variance, it can be implemented with torch.mean as follows:

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if __name__ == '__main__':
    x = torch.randn(1, 3, 512, 512).to(device)

    t_sigma, t_mean = torch.var_mean(x, dim=[2, 3], keepdim=True, unbiased=False)

    # Reproduce torch.var using torch.mean
    mean = torch.mean(x, dim=[2, 3], keepdim=True)
    var = torch.mean((x - mean) ** 2, dim=[2, 3], keepdim=True)

    print(torch.sum(var-t_sigma))
    print(torch.sum(mean-t_mean))

The test results were attached as an image.

TensorFlow's mean-and-variance op, tf.nn.moments(), converts successfully. In PyTorch, torch.var_mean() and torch.var() are not yet supported by the converter.
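torch.var_mean returns (variance, mean), while tf.nn.moments returns (mean, variance); in both cases one mean can feed both outputs, and the standard deviation from the issue title is the square root of the variance. The NumPy sketch below covers the population case only (matching unbiased=False in the test above); var_mean and std are hypothetical names, and a Core ML composite would presumably use mb.sqrt for the last step, which is an assumption rather than something verified here.

```python
import numpy as np


def var_mean(x, axes, keepdim=False):
    """Population variance and mean over `axes`, returned as (var, mean)
    in torch.var_mean's order; the mean is computed only once."""
    axes = tuple(axes)
    mean = x.mean(axis=axes, keepdims=True)  # keep axes so it broadcasts
    var = np.square(x - mean).mean(axis=axes, keepdims=keepdim)
    if not keepdim:
        mean = np.squeeze(mean, axis=axes)  # drop the size-1 reduced axes
    return var, mean


def std(x, axes, keepdim=False):
    """Population standard deviation as the square root of the variance."""
    var, _ = var_mean(x, axes, keepdim)
    return np.sqrt(var)


x = np.random.default_rng(2).standard_normal((1, 3, 8, 8))
v, m = var_mean(x, [2, 3], keepdim=True)
print(v.shape, m.shape)
```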