apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

Can't Convert Keras Model with Flexible Input Length and Dilated Convolution(s) #1389

Open omareldeeb opened 2 years ago

omareldeeb commented 2 years ago

❓Question

I have a Keras model that contains some dilated convolutions and whose input shape has a variable-length dimension. When I try to convert the model using `coremltools.convert()`, I get the following error: `Not support dynamic block_shape and paddings for SpaceToBatchND!` If I give the variable dimension a fixed value in the input shape, I get this error instead: `non-equal block shape is not yet supported for 4d input.` I managed to create two very simple examples that reproduce these two problems:

1. Variable Input Length
```python
import tensorflow as tf
import coremltools as ct

inputs = tf.keras.layers.Input(shape=(1, None, 1))
dilated_conv = tf.keras.layers.Conv1D(filters=20, kernel_size=5, dilation_rate=2)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=[dilated_conv])

ct.convert(model)    # Fails with error: "Not support dynamic block_shape and paddings for SpaceToBatchND!"
```

2. Fixed Input Length
```python
import tensorflow as tf
import coremltools as ct

inputs = tf.keras.layers.Input(shape=(1, 300, 1))
dilated_conv = tf.keras.layers.Conv1D(filters=20, kernel_size=5, dilation_rate=2)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=[dilated_conv])

ct.convert(model)    # Fails with error: "non-equal block shape is not yet supported for 4d input."
```

I've done some extensive debugging, but I have no idea how to approach this. Initially, padding was set to `'same'`, which I thought might be causing the issue, so I tried adding a padding layer before the convolution instead of specifying padding when creating the layer. But as you can see from the examples above, the conversion fails even if padding isn't specified at all. I would appreciate any help on this issue! :)

System Information

- `coremltools.__version__ == '5.1.0'`
- `tensorflow.__version__ == '2.5.0'`

TobyRoseman commented 2 years ago

For the first problem, I suspect the TensorFlow model is not valid. Before trying to convert to Core ML, are you able to get predictions out of that TensorFlow model?

For the second problem, that TensorFlow model looks good. It looks like we don't currently support that functionality. I'm not familiar with this part of the code base. For the example you give, `block_shape` is equal to `[1, 2]`. @jakesabathia2 - can you add any context here?

omareldeeb commented 2 years ago

The original TensorFlow model I'm trying to convert should be valid; I've been running predictions on it normally. The dummy one in the first code block of my initial question (consisting of just an input layer followed by a single dilated convolution) should be fine too. I just tried running a prediction on it, and it works (the input just has to be long enough, since no padding is specified).

I should add that both errors come from the same function attempting to convert the SpaceToBatchND Op in the TensorFlow graph. In my original model, block_shape is also equal to [1, 2] at this point of the conversion; but I'm not familiar enough with the entire process to understand what's going on, so any extra context at all would be great! :)
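Some background on where the SpaceToBatchND op comes from, offered as a sketch rather than a statement about the converter's internals: TensorFlow typically lowers a dilated convolution into a space-to-batch step, an ordinary convolution, and a batch-to-space step. The plain-Python example below shows that equivalence for a 1D signal. The helper names are hypothetical, it works on Python lists instead of tensors, and it calls neither TensorFlow nor coremltools. Note that the padding amount depends on the input length, which is one plausible reason (an assumption, not confirmed in this thread) why a variable-length input ends up with dynamic paddings.

```python
def dilated_conv1d(x, w, rate):
    """Direct unpadded, stride-1 dilated convolution (cross-correlation) on a list."""
    span = rate * (len(w) - 1) + 1  # effective kernel size
    return [
        sum(w[k] * x[i + k * rate] for k in range(len(w)))
        for i in range(len(x) - span + 1)
    ]


def dilated_via_space_to_batch(x, w, rate):
    """Same result, computed as space-to-batch -> plain conv -> batch-to-space."""
    # Pad so the length is a multiple of the block size (here: the dilation rate).
    # This amount depends on len(x), i.e. on the input length.
    pad = (-len(x)) % rate
    padded = list(x) + [0] * pad

    # Space-to-batch: phase p holds elements p, p + rate, p + 2 * rate, ...
    phases = [padded[p::rate] for p in range(rate)]

    # Ordinary (rate 1) convolution on every phase.
    conv_phases = [dilated_conv1d(phase, w, 1) for phase in phases]

    # Batch-to-space: interleave the per-phase outputs again.
    out = []
    for i in range(len(conv_phases[0])):
        for p in range(rate):
            out.append(conv_phases[p][i])

    # Crop the trailing outputs that only exist because of the zero padding.
    span = rate * (len(w) - 1) + 1
    return out[: max(0, len(x) - span + 1)]


signal = [(i * i) % 7 for i in range(21)]   # arbitrary test signal
kernel = [1, -2, 3, 1, 2]                   # kernel_size=5, as in the examples above
direct = dilated_conv1d(signal, kernel, rate=2)
lowered = dilated_via_space_to_batch(signal, kernel, rate=2)
print(direct == lowered)
```

With `rate=2`, the block size along the length axis is 2, which is consistent with the `[1, 2]` block shape mentioned above for an input whose first spatial dimension is 1.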

TobyRoseman commented 2 years ago

@omareldeeb - can you share code to get predictions from the first TensorFlow model?

omareldeeb commented 2 years ago

@TobyRoseman Sure:

```python
import numpy as np

sample_input = np.ones(shape=(1, 1, 500, 1), dtype=np.float32)    # The 500 here is totally arbitrary
model(sample_input)    # Returns a tensor of shape (1, 1, 492, 20)

# Another shape for the input
sample_input = np.ones(shape=(1, 1, 1371, 1), dtype=np.float32)
model(sample_input)    # Returns a tensor of shape (1, 1, 1363, 20)
```
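
The output lengths above follow from the usual arithmetic for an unpadded, stride-1 dilated convolution: the kernel covers `dilation_rate * (kernel_size - 1) + 1` input steps, so the output is that many steps minus one shorter than the input. A minimal sketch of that calculation (the function name is made up for illustration and is not part of TensorFlow or coremltools):

```python
def dilated_conv_output_length(input_length, kernel_size, dilation_rate):
    """Output length of an unpadded ('valid'), stride-1 dilated convolution."""
    effective_kernel = dilation_rate * (kernel_size - 1) + 1
    if input_length < effective_kernel:
        # This is the "input just has to be long enough" condition.
        raise ValueError(
            f"input length {input_length} is shorter than the "
            f"effective kernel size {effective_kernel}"
        )
    return input_length - effective_kernel + 1


print(dilated_conv_output_length(500, kernel_size=5, dilation_rate=2))   # 492
print(dilated_conv_output_length(1371, kernel_size=5, dilation_rate=2))  # 1363
```

For `kernel_size=5` and `dilation_rate=2`, the effective kernel size is 9, so any input shorter than 9 steps has no valid output position.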
omareldeeb commented 2 years ago

Interesting observation: Replacing the 1d convolutions with 2d convolutions that simulate a 1d convolution seems to eliminate the second problem ("non-equal block shape is not yet supported for 4d input."), i.e.:

```python
import tensorflow as tf
import coremltools as ct

inputs = tf.keras.layers.Input(shape=(1, 300, 1))
dilated_conv = tf.keras.layers.Conv2D(filters=20, kernel_size=(1, 5), dilation_rate=(1, 2))(inputs)    # This functions as a 1D Convolution
model = tf.keras.Model(inputs=inputs, outputs=[dilated_conv])

ct.convert(model)    # Works.
```

Unfortunately, this doesn't fix the first problem with the dynamic block_shape/paddings.