keras-team / keras

Using `+` in custom residual Keras `Layer` does not create correct model graph #16098

Closed · jacoblubecki closed this 1 year ago

jacoblubecki commented 2 years ago

This ticket was originally filed here: https://github.com/tensorflow/tensorflow/issues/54436

They asked me to move it over here, and since the templates were about the same I have copy/pasted it almost exactly. Please let me know if anything is missing.

System information

Also reproducible on various hosted Jupyter environments (Kaggle, Colab) with and without GPU.

Describe the current behavior

In a custom residual block implemented with the Keras APIs, the `+` operator yields a broken model graph.

def call(self, x, training=False):
    for block in self.blocks:
        h = x
        for conv in block:
            h = conv(h, training=training)

        x = x + h

    return x

This is the resulting graph: (image: res-block-broken)

Describe the expected behavior

This code produces the correct graph, where each add is a separate instance of `keras.layers.Add` (see the sketch below for how `self.adds` could be populated). The `+` operator should produce the same graph.

def call(self, x, training=False):
    for block, add in zip(self.blocks, self.adds):
        h = x
        for conv in block:
            h = conv(h, training=training)

        x = add([x, h])

    return x

(image: res-block-working)
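Note: `self.adds` is referenced above but its construction isn't shown. A minimal sketch of how it could be populated, as a hypothetical extension of the `build` method from the repro code below:

def build(self, input_shape):
    # ... existing block construction from the repro code ...
    # Hypothetical: one keras.layers.Add per residual block, so each
    # skip connection is a distinct layer instance in the model graph.
    self.adds = [L.Add() for _ in range(self.n_blocks)]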

I will add that this isn't just a visualization issue. My model would not train until I identified this problem and applied the fixed implementation described above. The missing skip connections were causing serious issues with my gradients, and the model was too deep to learn without them.


Standalone code to reproduce the issue

from typing import Optional, Tuple, Union

import tensorflow as tf
import tensorflow.keras.layers as L
from tensorflow import keras

def plot_model(model, shape):
    inputs = keras.Input(shape[1:])
    ones = tf.ones(shape)
    model(ones)  # I think this is needed to properly init the graph for plotting
    outputs = model.call(inputs)
    wrapped_model = keras.Model(inputs, outputs)
    return keras.utils.plot_model(
        wrapped_model, expand_nested=True, show_shapes=True)

class ConvBnAct(L.Layer):

    def __init__(
        self,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        activation: Optional[str] = 'swish',
        use_bias=False,
        use_batch_norm=True,
        data_format='channels_last'
            ):
        super().__init__()

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.use_bias = use_bias
        self.data_format = data_format

        self.activation = L.Activation(activation)
        self.act_type = activation

        bn_axis = 1 if data_format == 'channels_first' else -1
        self.batch_norm = L.BatchNormalization(
            axis=bn_axis) if use_batch_norm else None

    def build(self, input_shape):
        self.conv = L.Conv2D(
            self.out_channels,
            self.kernel_size,
            input_shape=input_shape[1:],
            padding='same',
            strides=self.stride,
            activation=None,
            use_bias=self.use_bias,
            data_format=self.data_format,
            )

    def call(self, inputs, training=False):
        x = self.conv(inputs)

        if self.batch_norm:
            x = self.batch_norm(x, training=training)

        if self.activation:
            x = self.activation(x)

        return x

class ResBlock(L.Layer):

    def __init__(
        self,
        blocks: int,
        shortcut=True,
        data_format='channels_last'
    ):
        super().__init__()
        self.n_blocks = blocks
        self.shortcut = shortcut
        self.data_format = data_format

    def build(self, input_shape):
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        channels = input_shape[channel_axis]

        self.blocks = []
        for i in range(self.n_blocks):
            block = [
                ConvBnAct(channels, kernel_size=1, stride=1, data_format=self.data_format),
                ConvBnAct(channels, kernel_size=3, stride=1, data_format=self.data_format)
                ]
            self.blocks.append(block)

    def call(self, x, training=False):
        for block in self.blocks:
            h = x
            for conv in block:
                h = conv(h, training=training)

            x = x + h if self.shortcut else h

        return x

if __name__ == '__main__':
    r = ResBlock(2, True)
    plot_model(r, (1, 24, 24, 3))
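As an extra sanity check beyond plotting (the snippet below is a sketch added for illustration, not part of the original report), the block can be wrapped into a functional model and the addition nodes counted; in TF 2.x, `+` on symbolic Keras tensors is traced as an op layer whose name contains "add":

r = ResBlock(2, True)
r(tf.ones((1, 24, 24, 3)))  # build the layer first, as in the helper above
inputs = keras.Input((24, 24, 3))
outputs = r.call(inputs)
wrapped = keras.Model(inputs, outputs)

# Each `x + h` should appear as its own add node in the traced graph.
add_nodes = [layer.name for layer in wrapped.layers if 'add' in layer.name]
print(add_nodes)  # expect one entry per residual block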
jacoblubecki commented 2 years ago

Copied from the original ticket, but here is some additional data pointing to a Keras-specific issue...


I replaced my ConvBnAct with the following:

class ConvBnAct(object):

    def __init__(
        self,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        activation: Optional[str] = 'swish',
        use_bias=False,
        use_batch_norm=True,
        data_format='channels_last'
            ):
        super().__init__()
        self.kernel = tf.ones((kernel_size, kernel_size, 3, 3))

    def __call__(self, inputs, training=False):
        x = tf.nn.conv2d(inputs, self.kernel, strides=1, padding='SAME')
        return x

Then I used TensorBoard to trace two versions of ResBlock.
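The tracing setup itself wasn't included; below is a minimal sketch of how such a TensorBoard trace could be captured with the standard tf.summary tracing APIs (the log directory and names are illustrative assumptions):

# Assumed tracing harness (TF 2.x summary APIs).
writer = tf.summary.create_file_writer('logs/resblock')
block = ResBlock(2)  # either of the two variants below

@tf.function
def run(x):
    return block(x)

tf.summary.trace_on(graph=True)
run(tf.ones((1, 24, 24, 3)))
with writer.as_default():
    tf.summary.trace_export(name='resblock', step=0)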

Keras version using only TensorFlow APIs (still inherits from the Keras Layer, but otherwise only uses the tf namespace):

class ResBlock(L.Layer):

    def __init__(self, blocks: int, data_format='channels_last'):
        super().__init__()
        self.n_blocks = blocks
        self.data_format = data_format

        self.blocks = []
        for i in range(self.n_blocks):
            block = [
                ConvBnAct(3, kernel_size=1, stride=1, data_format=self.data_format),
                ConvBnAct(3, kernel_size=3, stride=1, data_format=self.data_format)
                ]
            self.blocks.append(block)

    def call(self, x, training=False):
        for block in self.blocks:
            h = x
            for conv in block:
                h = conv(h, training=training)

            x = x + h

        return x

And the pure TensorFlow version (with no reference to any Keras APIs):

class ResBlock(object):

    def __init__(self, blocks: int, data_format='channels_last'):
        super().__init__()
        self.n_blocks = blocks

        self.blocks = []
        for i in range(self.n_blocks):
            block = [
                ConvBnAct(3, kernel_size=1, stride=1, data_format=data_format),
                ConvBnAct(3, kernel_size=3, stride=1, data_format=data_format)
                ]
            self.blocks.append(block)

    def __call__(self, x, training=False):
        for block in self.blocks:
            h = x
            for conv in block:
                h = conv(h, training=training)

            x = x + h

        return x

The Keras version still appears broken in TensorBoard, but the pure version with no Keras APIs appears to work fine.

fchollet commented 2 years ago

This appears to be a bug in plot_model, which was fixed some time ago. You can try again with tf-nightly, and you should be able to see the expected model plots. Note that the issue was only affecting the plots in the first place, not the models themselves.

Also note that you should never use model.call directly. call() is an implementer-facing API only. Always use model() instead.
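As a sketch of the distinction (this snippet is an illustration, not part of the comment above):

model = ResBlock(2, True)
x = tf.ones((1, 24, 24, 3))

y = model(x, training=False)  # public API: triggers build(), handles training flags and name scoping
# model.call(x)               # implementer-facing: bypasses the Keras bookkeeping above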

jacoblubecki commented 2 years ago

@fchollet Can you link me to the related ticket?

> This appears to be a bug in plot_model, which was fixed some time ago. You can try again with tf-nightly, and you should be able to see the expected model plots. Note that the issue was only affecting the plots in the first place, not the models themselves.

Even so, the model wouldn't train before I made the change from `+` to `add`, so I suspect it may be more than a plotting issue. Other than this change, the environment was pinned, deterministic ops were enabled, and random number generators were all seeded.

Still, I did update to tf-nightly and the graph was correct. Maybe I was just extremely unlucky in my training or inadvertently fixed another bug without noticing?

> Also note that you should never use model.call directly. call() is an implementer-facing API only. Always use model() instead.

I definitely don't do this normally; the plotting code was just showing me a completely flat model (Input -> ResBlock -> Out), so I just tried things until it started working. Still, good to know.

SuryanarayanaY commented 1 year ago

Hello, thank you for reporting an issue.

We're currently in the process of migrating the new Keras 3 code base from keras-team/keras-core to keras-team/keras. Consequently, this issue may not be relevant to the Keras 3 code base. After the migration is successfully completed, feel free to reopen this issue at keras-team/keras if you believe it remains relevant to Keras 3. If instead this issue is a bug or security issue in legacy tf.keras, please report a new issue at keras-team/tf-keras, which hosts the TensorFlow-only, legacy version of Keras.

To learn more about Keras 3, please read https://keras.io/keras_core/announcement/