apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io

Adaptive_pooling does not work with EnumeratedShapes. #976

Open · starsky opened 3 years ago

starsky commented 3 years ago

The issue:

In the case of adaptive pooling, only the output tensor size is specified, e.g. (2, 2); the kernel size, stride, etc. have to be computed from the input tensor size. But when I run the conversion I get the following error:

Trace

ValueError                                Traceback (most recent call last)
<ipython-input-1-2949dbfe456f> in <module>
     27 model = ct.convert(
     28     traced_model,
---> 29     inputs=[inp])

~/projects/mmhmm/coremltools/coremltools/converters/_converters_entry.py in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, **kwargs)
    308             outputs=outputs,
    309             classifier_config=classifier_config,
--> 310             **kwargs
    311         )
    312 

~/projects/mmhmm/coremltools/coremltools/converters/mil/converter.py in _convert(model, convert_from, convert_to, converter_registry, **kwargs)
    130         )
    131     frontend_converter = frontend_converter_type()
--> 132     prog = frontend_converter(model, **kwargs)
    133     common_pass(prog)
    134 

~/projects/mmhmm/coremltools/coremltools/converters/mil/converter.py in __call__(self, *args, **kwargs)
     82         from .frontend.torch import load
     83 
---> 84         return load(*args, **kwargs)
     85 
     86 

~/projects/mmhmm/coremltools/coremltools/converters/mil/frontend/torch/load.py in load(model_spec, debug, **kwargs)
     84         raise e
     85     except Exception as e:
---> 86         raise e
     87 
     88     return prog

~/projects/mmhmm/coremltools/coremltools/converters/mil/frontend/torch/load.py in load(model_spec, debug, **kwargs)
     74 
     75     try:
---> 76         prog = converter.convert()
     77     except RuntimeError as e:
     78         if debug and "convert function" in str(e):

~/projects/mmhmm/coremltools/coremltools/converters/mil/frontend/torch/converter.py in convert(self)
    222 
    223             # Add the rest of the operations
--> 224             convert_nodes(self.context, self.graph)
    225 
    226             graph_outputs = [self.context[name] for name in self.graph.outputs]

~/projects/mmhmm/coremltools/coremltools/converters/mil/frontend/torch/ops.py in convert_nodes(context, graph)
     53             )
     54         else:
---> 55             _add_op(context, node)
     56 
     57         # We've generated all the outputs the graph needs, terminate conversion.

~/projects/mmhmm/coremltools/coremltools/converters/mil/frontend/torch/ops.py in adaptive_avg_pool2d(context, node)
    813             pad_type=pad_type,
    814             pad=pad,
--> 815             name=node.name,
    816         )
    817     else:

~/projects/mmhmm/coremltools/coremltools/converters/mil/mil/ops/registry.py in add_op(cls, **kwargs)
     60             @classmethod
     61             def add_op(cls, **kwargs):
---> 62                 return cls._add_op(op_cls, **kwargs)
     63 
     64             setattr(Builder, op_type, add_op)

~/projects/mmhmm/coremltools/coremltools/converters/mil/mil/builder.py in _add_op(cls, op_cls, **kwargs)
    185         kwargs = {k: v for k, v in kwargs.items() if v is not None}
    186         kwargs = cls._create_input_vars(
--> 187             op_cls.input_spec, kwargs["name"], op_cls, before_op, kwargs
    188         )
    189         new_op = op_cls(**kwargs)

~/projects/mmhmm/coremltools/coremltools/converters/mil/mil/builder.py in _create_input_vars(cls, input_spec, op_name, op_cls, before_op, kwargs)
    144                             )
    145                     elif isinstance(in_type, (ScalarOrTensorInputType, ListOrScalarOrTensorInputType)):
--> 146                         var = cls._add_const(val, new_var_name, before_op)
    147                     else:
    148                         msg = "Cannot convert input '{}' of type {} to Var (op: {})"

~/projects/mmhmm/coremltools/coremltools/converters/mil/mil/builder.py in _add_const(cls, val, name, before_op)
     81     def _add_const(cls, val, name, before_op):
     82         if not is_python_value(val):
---> 83             raise ValueError("Cannot add const {}".format(val))
     84         if any_symbolic(val):
     85             msg = (

ValueError: Cannot add const [is0 - floor(is0/2), is1 - floor(is1/2)]

It seems to me that the shape of the input is represented as a sympy expression, but it is not handled well by the converter.
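For context, here is a minimal sketch of the arithmetic the torch frontend appears to apply here (my own illustration, not the converter's actual code): the stride and kernel size of the substitute average pool are derived from the input spatial size, so a symbolic height or width ends up inside the kernel-size expression.

def adaptive_avg_pool2d_params(input_hw, output_hw):
    # Derive kernel sizes and strides for a plain average pool from the
    # input spatial size. With concrete integers this yields plain ints;
    # with a symbolic height/width (as produced by EnumeratedShapes) the
    # same arithmetic yields a symbolic expression, which the builder's
    # _add_const refuses to turn into a constant.
    strides = [inp // out for inp, out in zip(input_hw, output_hw)]
    kernels = [inp - (out - 1) * s
               for inp, out, s in zip(input_hw, output_hw, strides)]
    return kernels, strides

print(adaptive_avg_pool2d_params((500, 500), (2, 2)))
# -> ([250, 250], [250, 250]) for a concrete 500x500 input.
# For a symbolic size is0 the kernel entry becomes is0 - floor(is0/2),
# which is exactly the value in the error message above.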

To Reproduce

import torch
import coremltools as ct
import numpy as np
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

class SomeModel(nn.Module):

    def forward(self, inp):
        return F.adaptive_avg_pool2d(inp, output_size=(2,2))

py_model = SomeModel()

# This part is just to create PyTorch traced model
img_size = 500
example_input = np.random.randint(0, 255, (img_size, img_size, 3), dtype=np.uint8)  # uint8 so PIL can build an RGB image
example_input_img = Image.fromarray(example_input, 'RGB')
example_input_tensor = transforms.ToTensor()(example_input_img).view(1, 3, img_size, img_size)
traced_model = torch.jit.trace(py_model, example_input_tensor)

from coremltools import EnumeratedShapes
shp = EnumeratedShapes(shapes=[(1, 3, 200, 200), (1, 3, 500, 500)])
inp = ct.TensorType(name="input_1", shape=shp)
model = ct.convert(
    traced_model,
    inputs=[inp])

As you can see, in this example I try to convert a simple model that consists of a single adaptive pooling operation.
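For comparison, the sketch below (continuing the snippet above) converts the same traced model with a single fixed input shape. My assumption, not verified in this thread, is that conversion succeeds in that case, because concrete height and width make the derived kernel size and stride plain integer constants instead of symbolic expressions.

# Sketch: same model, but with a fixed (non-enumerated) input shape.
# Assumption: concrete spatial dims keep the pooling parameters as plain
# ints, so the "Cannot add const" error above should not be triggered.
fixed_inp = ct.TensorType(name="input_1", shape=(1, 3, 500, 500))
fixed_model = ct.convert(traced_model, inputs=[fixed_inp])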

System environment (please complete the following information):

anusha66 commented 2 years ago

It seems to me that the shape of the input is represented as a sympy expression, but it is not handled well by the converter.

I am facing the same issue with Adaptive pooling. Is there a fix for this?

TobyRoseman commented 1 year ago

This is still an issue in coremltools 6.0.