keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.
Apache License 2.0

keras.layers.Conv2D with dilation_rate > 1 becomes SpaceToBatchNd-Conv2D-BatchToSpaceNd in tflite conversion #290

Open james77777778 opened 1 year ago

james77777778 commented 1 year ago

System information.

Describe the problem.

keras.layers.Conv2D calls tf.nn.convolution (https://github.com/keras-team/keras/blob/master/keras/layers/convolutional/base_conv.py#L247-L263), which produces a SpaceToBatchNd-Conv2D-BatchToSpaceNd pattern during tflite conversion.

Describe the current behavior.

When converting a functional model that contains a keras.layers.Conv2D layer with dilation_rate > 1, the SpaceToBatchNd-Conv2D-BatchToSpaceNd pattern appears in the converted model. The pattern can be visualized with Netron.
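For readers unfamiliar with the pattern, here is a minimal one-dimensional, pure-Python sketch of why a dilated convolution can be rewritten as space-to-batch, ordinary convolution, batch-to-space. The helper names are hypothetical and this is not the TensorFlow implementation; it assumes "valid" padding and a signal length that is a multiple of the dilation rate.

```python
def dilated_conv1d(signal, kernel, rate):
    """Direct 'valid' dilated cross-correlation: taps are `rate` samples apart."""
    span = (len(kernel) - 1) * rate
    return [
        sum(kernel[k] * signal[i + k * rate] for k in range(len(kernel)))
        for i in range(len(signal) - span)
    ]


def space_to_batch_conv1d(signal, kernel, rate):
    """Same result via split -> ordinary conv -> interleave.

    Assumption: len(signal) is a multiple of `rate` (no extra padding handled).
    """
    # "SpaceToBatch": split the signal into `rate` interleaved phases.
    phases = [signal[p::rate] for p in range(rate)]
    # Ordinary (dilation 1) convolution on every phase.
    outs = [dilated_conv1d(phase, kernel, 1) for phase in phases]
    # "BatchToSpace": interleave the per-phase outputs back into one signal.
    return [outs[p][j] for j in range(len(outs[0])) for p in range(rate)]


signal = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8]
kernel = [1, 0, -1]
print(dilated_conv1d(signal, kernel, 3) == space_to_batch_conv1d(signal, kernel, 3))
```

Both routes compute the same values, so the converted model is numerically fine; the concern in this issue is the two extra ops and the lost dilation attributes.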

Also, the dilation_h_factor and dilation_w_factor attributes of the resulting Conv2D op are useless because they are always 1.

We can work around this by replacing tf.nn.convolution with tf.nn.conv2d in a custom layer, but that is inconvenient if we want to further optimize the model with tfmot, which encourages users to stick to built-in Keras layers.
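A rough sketch of what such a custom layer might look like is given here. The class name Conv2DViaConv2dOp is hypothetical, and the sketch assumes channels_last inputs, 'valid' or 'same' padding, and that the installed Keras version dispatches the convolution through convolution_op; treat it as an illustration rather than a tested fix.

```python
import numpy as np
import tensorflow as tf


class Conv2DViaConv2dOp(tf.keras.layers.Conv2D):
    """Hypothetical workaround: Conv2D that calls tf.nn.conv2d directly."""

    def convolution_op(self, inputs, kernel):
        # Assumption: channels_last (NHWC) data and 'valid'/'same' padding.
        return tf.nn.conv2d(
            inputs,
            kernel,
            strides=list(self.strides),
            padding=self.padding.upper(),
            dilations=list(self.dilation_rate),
            data_format="NHWC",
        )


# Sanity check: with identical weights, the outputs should match the
# built-in layer.
x = np.random.rand(1, 16, 16, 3).astype("float32")
builtin = tf.keras.layers.Conv2D(4, 3, padding="same", dilation_rate=3, use_bias=False)
custom = Conv2DViaConv2dOp(4, 3, padding="same", dilation_rate=3, use_bias=False)
y_builtin = np.asarray(builtin(x))
custom(x)  # first call builds the layer so weights can be copied
custom.set_weights(builtin.get_weights())
y_custom = np.asarray(custom(x))
print(np.allclose(y_builtin, y_custom, atol=1e-5))
```

Because this subclasses the built-in layer, tooling that checks for exact built-in layer types may still not recognize it, which is the inconvenience described above.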

Describe the expected behavior.

keras.layers.Conv2D with dilation_rate > 1 should be converted to a single Conv2D op that carries the actual dilation factors, instead of the SpaceToBatchNd-Conv2D-BatchToSpaceNd pattern.

Contributing.

The fix is a single-line modification:

class Conv2D(Conv):
    def __init__(
        ...
    ):
        super().__init__(
            ...
        )

    def convolution_op(self, inputs, kernel):
        ...

        return tf.nn.conv2d(  # => change from tf.nn.convolution
            ...
        )

Standalone code to reproduce the issue.

import tensorflow as tf

x = tf.keras.layers.Input((224, 224, 3))
y = tf.keras.layers.Conv2D(16, 3, 1, padding='same', dilation_rate=3, use_bias=False)(x)
y = tf.keras.layers.BatchNormalization()(y)
y = tf.keras.layers.ReLU()(y)

model = tf.keras.Model(x, y)
model.summary()
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

sushreebarsa commented 1 year ago

@james77777778 Could you please have a look at the gist in the latest TF version and confirm the output? Please feel free to raise a PR to fix the issue. Thank you!

james77777778 commented 1 year ago

@sushreebarsa I ran the code you provided and verified that the issue still persists.

(visualization from netron)

I will try to raise the PR when I'm back from my vacation... :)