apple / coremltools

Core ML Tools contains supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

NotImplementedError("rank of input must be 3 or 4!") #1560

Status: Open. Opened by veeresh-dammur 2 years ago

veeresh-dammur commented 2 years ago

Hi,

I am trying to port Google's RepNet model to the Core ML framework. Model inference works fine, but coremltools throws the following error when converting the TF model to Core ML:

Traceback (most recent call last):
  File "/home/veeresh/anaconda3/envs/fraser_tf_26/lib/python3.7/site-packages/coremltools/converters/mil/frontend/tensorflow/converter.py", line 330, in convert_main_graph
    outputs = convert_graph(self.context, graph, self.outputs)
  File "/home/veeresh/anaconda3/envs/fraser_tf_26/lib/python3.7/site-packages/coremltools/converters/mil/frontend/tensorflow/convert_utils.py", line 189, in convert_graph
    add_op(context, node)
  File "/home/veeresh/anaconda3/envs/fraser_tf_26/lib/python3.7/site-packages/coremltools/converters/mil/frontend/tensorflow/ops.py", line 1951, in SpaceToBatchND
    raise NotImplementedError("rank of input must be 3 or 4!")
NotImplementedError: rank of input must be 3 or 4!

Setup:

Ubuntu 20.04
Python: 3.7.13
TensorFlow: 2.6.2
coremltools: 5.2.0

TobyRoseman commented 2 years ago

Your colab notebook is quite long and complicated. Can you figure out a minimal example that reproduces this issue?

veeresh-dammur commented 2 years ago

Hi, please take a look at the code snippet below to reproduce the error.

import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow.keras.regularizers as regularizers
import coremltools as ct

class RepNet(tf.keras.Model):

    def __init__(self):
        super(RepNet, self).__init__()
        self.base_model_layer_name = 'conv4_block3_out'
        base_model = tf.keras.applications.ResNet50V2(
            include_top=False, weights=None, pooling='max')
        self.base_model = tf.keras.models.Model(
            inputs=base_model.input,
            outputs=base_model.get_layer(self.base_model_layer_name).output)
        self.temporal_conv_layers = [
            layers.Conv3D(512,
                          3,
                          padding='same',
                          dilation_rate=(3, 1, 1),
                          kernel_regularizer=regularizers.l2(1e-6),
                          kernel_initializer='he_normal')]
        self.temporal_bn_layers = [layers.BatchNormalization()
                                   for _ in self.temporal_conv_layers]
        self.num_frames = 64
        self.image_size = 112
        self.temperature = 13.544
        self.dropout_rate = 0.25
        self.l2_reg_weight = 1e-6

    def call(self, x):
        batch_size = tf.shape(x)[0]
        x = tf.reshape(x, [-1, self.image_size, self.image_size, 3])
        x = self.base_model(x)
        h = tf.shape(x)[1]
        w = tf.shape(x)[2]
        c = tf.shape(x)[3]
        x = tf.reshape(x, [batch_size, -1, h, w, c])
        for bn_layer, conv_layer in zip(self.temporal_bn_layers,
                                        self.temporal_conv_layers):
            x = conv_layer(x)
            x = bn_layer(x)
            x = tf.nn.relu(x)
        x = tf.reduce_max(x, [2, 3])
        return x

if __name__ == "__main__":
    print("TensorFlow version", tf.__version__)    # 2.6.2
    print("coremltools version", ct.__version__)   # 5.0
    model = RepNet()
    model(tf.zeros((1, 64, 224, 224, 3)))
    print(model.summary())
    core_model = ct.convert(model, convert_to="neuralnetwork", source="tensorflow")
    print()


TobyRoseman commented 2 years ago

Thanks for the smaller example. With TensorFlow 2.8.0 (the newest TensorFlow version we support) and coremltools 6.0b1, I get a different error message and stack trace:

File ~/miniconda3/lib/python3.9/site-packages/coremltools/converters/_converters_entry.py:426, in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, package_dir, debug)
    423 if specification_version is None:
    424     specification_version = _set_default_specification_version(exact_target)
--> 426 mlmodel = mil_convert(
    427     model,
    428     convert_from=exact_source,
    429     convert_to=exact_target,
    430     inputs=inputs,
    431     outputs=outputs_as_tensor_or_image_types, # None or list[ct.ImageType/ct.TensorType]
    432     classifier_config=classifier_config,
    433     transforms=tuple(transforms),
    434     skip_model_load=skip_model_load,
    435     compute_units=compute_units,
    436     package_dir=package_dir,
    437     debug=debug,
    438     specification_version=specification_version,
    439 )
    441 if exact_target == 'milinternal':
    442     return mlmodel # Returns the MIL program

File ~/miniconda3/lib/python3.9/site-packages/coremltools/converters/mil/converter.py:182, in mil_convert(model, convert_from, convert_to, compute_units, **kwargs)
    143 @_profile
    144 def mil_convert(
    145     model,
   (...)
    149     **kwargs
    150 ):
    151     """
    152     Convert model from a specified frontend `convert_from` to a specified
    153     converter backend `convert_to`.
   (...)
    180         See `coremltools.converters.convert`
    181     """
--> 182     return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)

File ~/miniconda3/lib/python3.9/site-packages/coremltools/converters/mil/converter.py:209, in _mil_convert(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)
    206     # To make sure everyone can read and write to this directory (on par with os.mkdir())
    207     _os.chmod(weights_dir, _stat.S_IRWXU | _stat.S_IRWXG | _stat.S_IRWXO)
--> 209 proto, mil_program = mil_convert_to_proto(
    210                         model,
    211                         convert_from,
    212                         convert_to,
    213                         registry,
    214                         **kwargs
    215                      )
    217 _reset_conversion_state()
    219 if convert_to == 'milinternal':

File ~/miniconda3/lib/python3.9/site-packages/coremltools/converters/mil/converter.py:272, in mil_convert_to_proto(model, convert_from, convert_to, converter_registry, **kwargs)
    269 kwargs.setdefault("convert_to", convert_to)
    270 frontend_converter = frontend_converter_type()
--> 272 prog = frontend_converter(model, **kwargs)
    274 if convert_to.lower() != "neuralnetwork":
    275     passes = kwargs.get("transforms", list())

File ~/miniconda3/lib/python3.9/site-packages/coremltools/converters/mil/converter.py:94, in TensorFlow2Frontend.__call__(self, *args, **kwargs)
     91 from .frontend.tensorflow2.load import TF2Loader
     93 tf2_loader = TF2Loader(*args, **kwargs)
---> 94 return tf2_loader.load()

File ~/miniconda3/lib/python3.9/site-packages/coremltools/converters/mil/frontend/tensorflow/load.py:86, in TFLoader.load(self)
     79     dot_string = self._tf_ssa.get_dot_string(
     80         annotation=True, name_and_op_style=True, highlight_debug_nodes=[]
     81     )
     82     graphviz.Source(dot_string).view(
     83         filename="/tmp/ssa_before_tf_passes", cleanup=True
     84     )
---> 86 program = self._program_from_tf_ssa()
     87 logging.debug("program:\n{}".format(program))
     88 return program

File ~/miniconda3/lib/python3.9/site-packages/coremltools/converters/mil/frontend/tensorflow2/load.py:200, in TF2Loader._program_from_tf_ssa(self)
    198 self._run_tf_ssa_passes()
    199 converter = TF2Converter(self._tf_ssa, **self.kwargs)
--> 200 return converter.convert()

File ~/miniconda3/lib/python3.9/site-packages/coremltools/converters/mil/frontend/tensorflow/converter.py:473, in TFConverter.convert(self)
    471 for g_name in self.graph_stack[1:]:
    472     self.context.add_graph(g_name, self.tfssa.functions[g_name].graph)
--> 473 self.convert_main_graph(prog, graph)
    475 # Apply TF frontend passes on Program. These passes are different
    476 # from passes applied to tfssa.
    477 self.tensorflow_passes(prog)

File ~/miniconda3/lib/python3.9/site-packages/coremltools/converters/mil/frontend/tensorflow/converter.py:396, in TFConverter.convert_main_graph(self, prog, graph)
    394         input_var = mb.cast(x=input_var, dtype="fp32", name=name)
    395     self.context.add(name, input_var)
--> 396 outputs = convert_graph(self.context, graph, self.output_names)
    397 ssa_func.set_outputs(outputs)
    398 prog.add_function("main", ssa_func)

File ~/miniconda3/lib/python3.9/site-packages/coremltools/converters/mil/frontend/tensorflow/convert_utils.py:189, in convert_graph(context, graph, outputs)
    185     msg = "Conversion for TF op '{0}' not implemented.\n \n{1}".format(
    186         node.op, node.original_node
    187     )
    188     raise NotImplementedError(msg)
--> 189 add_op(context, node)
    191 if len(node.outputs) > 0:
    192     # set_global / get_global / NoOp has no direct consumer / outputs
    193     x = context[node.name]

File ~/miniconda3/lib/python3.9/site-packages/coremltools/converters/mil/frontend/tensorflow/ops.py:2018, in SpaceToBatchND(context, node)
   2015     x = _reshape_remaining_dimensions_to_canonical_shape(x, remaining_rank)
   2017 if spatial_rank >= 3:
-> 2018     raise NotImplementedError("Rank of spatial shape > 2 is not supported.")
   2020 if spatial_rank == 2:
   2021     # Tensor has shape [B, H, W, C], we can directly use the space_to_batch op by doing
   2022     # [B, H, W, C] -> transpose -> [B, C, H, W] -> space_to_batch -> [B_new, C, H_new, W_new] ->
   2023     # transpose -> [B_new, H_new, W_new, C]
   2024     x = mb.transpose(x=x, perm=[0, 3, 1, 2])

NotImplementedError: Rank of spatial shape > 2 is not supported.

veeresh-dammur commented 2 years ago

Thanks for your reply. I get the same error with tf=2.8.0 and ct=6.0b1. Is there a workaround for this issue? Is there any possibility this bug will be fixed in an upcoming release?
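
For context on why SpaceToBatchND shows up at all: TensorFlow lowers a convolution with dilation_rate > 1 into SpaceToBatchND -> (undilated) conv -> BatchToSpaceND, and for the Conv3D in the repro the spatial rank is 3, which the converter rejects. The decomposition itself is just phase slicing: a conv with dilation d over the input equals d stride-d slices each convolved without dilation, interleaved back. A minimal 1-D NumPy sketch of that equivalence (the helper names are hypothetical, not coremltools or TF APIs), which also hints at a possible workaround of implementing the dilated temporal conv manually instead of using dilation_rate:

```python
import numpy as np

def dilated_conv1d(x, w, d):
    """Direct 1-D 'valid' convolution with dilation rate d."""
    k = len(w)
    out_len = len(x) - d * (k - 1)
    return np.array([sum(x[i + j * d] * w[j] for j in range(k))
                     for i in range(out_len)])

def conv1d(x, w):
    """Plain stride-1, dilation-1 'valid' convolution."""
    return dilated_conv1d(x, w, 1)

def dilated_via_space_to_batch(x, w, d):
    """Same result via the SpaceToBatch trick: split the input into d
    phase slices (space_to_batch moves them into the batch dim), run an
    undilated conv on each slice, then interleave the per-phase outputs
    back into spatial order (batch_to_space)."""
    assert len(x) % d == 0, "sketch assumes no padding is needed"
    phases = [conv1d(x[p::d], w) for p in range(d)]
    out_len = len(x) - d * (len(w) - 1)
    y = np.empty(out_len)
    for p in range(d):
        y[p::d] = phases[p][: len(y[p::d])]
    return y

x = np.arange(12.0)
w = np.array([1.0, 2.0, 3.0])
print(dilated_conv1d(x, w, 3))              # [24. 30. 36. 42. 48. 54.]
print(dilated_via_space_to_batch(x, w, 3))  # same values
```

Since each phase slice is convolved with dilation 1, a model that does this slicing explicitly would only hand the converter undilated convs, sidestepping the SpaceToBatchND op entirely; whether that is practical for the RepNet Conv3D is untested here.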