keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Conv2D Layer in PyTorch and Keras. Issue with padding elements. #19291

Closed: abhaskumarsinha closed this issue 6 months ago

abhaskumarsinha commented 7 months ago

Hello there everyone,

I got stuck on some problems while comparing the Conv2D layer of Keras with the Conv2d layer of PyTorch used in the YOLOv9 model.

When I created a Conv2D layer in Keras and compared it with the one used in YOLOv9, I found that the values at the padded (border) positions of the two outputs do not match. Here's the code for both layers:

import torch.nn as nn

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    # Pad to 'same' shape outputs
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

# Here c1 = 3, c2 = 64, kernel_size = 3, stride = 2, padding = 1 (from autopad above), dilation = 1, groups = 1, padding_mode = 'zeros'
c1, c2, k, s, p, g, d = 3, 64, 3, 2, None, 1, 1
conv_pyt = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
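The `padding = 1` in the comment above can be checked against the output-size formula given in the `torch.nn.Conv2d` documentation. Below is a minimal sketch; `torch_conv_out` is a hypothetical helper that just restates that formula for one spatial dimension, and `autopad` is an integer-only copy of the helper above.

```python
def autopad(k, p=None, d=1):
    """Integer-only copy of the autopad helper above."""
    if d > 1:
        k = d * (k - 1) + 1  # effective (dilated) kernel size
    if p is None:
        p = k // 2           # symmetric padding: same amount on both sides
    return p

def torch_conv_out(n, k, s=1, p=0, d=1):
    """Output length along one spatial dim of torch.nn.Conv2d with symmetric padding p."""
    return (n + 2 * p - d * (k - 1) - 1) // s + 1

p = autopad(3)
print(p, torch_conv_out(128, k=3, s=2, p=p))  # 1 64
```

So PyTorch adds one zero row/column on every side and produces a 64x64 map from a 128x128 input, the same shape Keras `padding='same'` gives with stride 2. The shapes agree; only where the zeros go differs.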

Next, I created a Keras Conv2D layer that should behave the same way:

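import keras

# Inputs are channels-first (N, C, H, W) like PyTorch; Keras Conv2D defaults to channels-last.
# Note: the permutation (0, 3, 2, 1) gives (N, W, H, C), i.e. it also swaps H and W;
# (0, 2, 3, 1) and its inverse (0, 3, 1, 2) would keep the spatial axes in order.
input = keras.Input(shape=(3, 128, 128))
tp = keras.ops.transpose(input, (0, 3, 2, 1))
conv = keras.layers.Conv2D(64, kernel_size=3, strides=2, padding='same')(tp)
output = keras.ops.transpose(conv, (0, 3, 2, 1))

keras_conv = keras.Model(inputs=input, outputs=output)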

Testing both layers on an all-ones input of the same shape and comparing the outputs element-wise gives:

array([[[[False, False, False, ..., False, False, False],
         [False,  True,  True, ...,  True,  True, False],
         [False,  True,  True, ...,  True,  True, False],
         ...,
         [False,  True,  True, ...,  True,  True, False],
         [False,  True,  True, ...,  True,  True, False],
         [False, False, False, ..., False, False, False]],

        [[False, False, False, ..., False, False, False],
         [False,  True,  True, ...,  True,  True, False],
         [False,  True,  True, ...,  True,  True, False],
         ...,
         [False,  True,  True, ...,  True,  True, False],
         [False,  True,  True, ...,  True,  True, False],
         [False, False, False, ..., False, False, False]],
...

The outputs were compared to 3 decimal places.

The elements on the edges and corners clearly mismatch between the two layers, while the interior matches. I believe this is due to a difference in how the two frameworks pad the input.
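The border pattern can be reproduced without either framework. The sketch below is an assumption-laden stand-in: a naive single-channel NumPy cross-correlation on an 8x8 input (instead of 3x128x128), random 3x3 weights, and hypothetical helpers `tf_same_pads`, `conv2d_valid` and `compare`. Keras/TF `'same'` with stride 2 on an even size needs one pixel of padding in total and puts it at the end (bottom/right), while PyTorch `padding=1` puts one pixel on each side, so every output window is shifted by one pixel between the two.

```python
import numpy as np

def tf_same_pads(n, k, s):
    """(before, after) zeros that TF/Keras 'same' padding adds along one spatial dim."""
    out = -(-n // s)                        # ceil(n / s)
    total = max((out - 1) * s + k - n, 0)
    return total // 2, total - total // 2   # an odd extra pixel goes at the end

def conv2d_valid(x, w, s):
    """Naive single-channel cross-correlation with stride s and no padding."""
    k = w.shape[0]
    oh = (x.shape[0] - k) // s + 1
    ow = (x.shape[1] - k) // s + 1
    y = np.empty((oh, ow))
    for i in range(oh):
        for j in range(ow):
            y[i, j] = np.sum(x[i * s:i * s + k, j * s:j * s + k] * w)
    return y

def compare(x, w, s=2):
    """Element-wise match between Keras-style 'same' and PyTorch-style padding=k//2."""
    k = w.shape[0]
    before, after = tf_same_pads(x.shape[0], k, s)  # square input assumed
    keras_like = conv2d_valid(np.pad(x, ((before, after), (before, after))), w, s)
    torch_like = conv2d_valid(np.pad(x, ((k // 2, k // 2), (k // 2, k // 2))), w, s)
    return np.isclose(keras_like, torch_like)

rng = np.random.default_rng(0)
w = rng.standard_normal((3, 3))
ones_mask = compare(np.ones((8, 8)), w)
random_mask = compare(rng.standard_normal((8, 8)), w)
print(ones_mask.astype(int))
print(random_mask.any())
```

With the all-ones input only the inner block matches, reproducing the True/False border pattern above, because every full window over ones sums to the same value. With a random input no position should match, so the all-ones test hides the fact that the interior is also sampled from shifted windows.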

Here's a reproducible notebook to play around with: https://colab.research.google.com/drive/1qibf8UKo2GQBkyh0Ppcz9Jros30ltCUb?usp=sharing

Is this a bug in Keras, or expected behavior?

abhaskumarsinha commented 7 months ago

This code is also supposed to be part of YOLOv7, but I've run into another, bigger problem after that. Let me add the code first.

import keras
import tensorflow as tf

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    # Pad to 'same' shape outputs
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class TFPad(keras.layers.Layer):
    # Pad inputs in spatial dimensions 1 and 2
    def __init__(self, pad):
        super().__init__()
        if isinstance(pad, int):
            self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
        else:  # tuple/list
            self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]])

    def call(self, inputs):
        return tf.pad(inputs, self.pad, mode='constant', constant_values=0)

# So what is this TFPad class doing here?
# I decided to use the official PyTorch weights instead of training the models from scratch.
# TensorFlow's 'same' padding puts the extra zero row/column at the bottom/right when the total
# padding is odd (which happens here with stride 2), while PyTorch pads symmetrically,
# so we are forced to use this simple hack of manually padding the inputs and using 'valid'.
# The credit goes to Ultralytics for this.
# https://github.com/ultralytics/yolov5/blob/c4c0ee8fc35937cfa940fdaaaf6b9660f5b355f5/models/tf.py#L72
@keras.utils.register_keras_serializable()
class Conv(keras.layers.Layer):
    def __init__(self, filters, kernel_size=1, strides=1, padding=None, groups=1,
                 act=True, name='_', deploy=False, **kwargs):
        super(Conv, self).__init__(name=name, **kwargs)

        self.deploy = deploy
        self.filters = filters
        self.kernel_size = kernel_size
        self.padding = padding
        self.strides = strides
        self.groups = groups

        if strides == 1:
            self.conv = keras.layers.Conv2D(filters, kernel_size, strides, padding='same',
                                            groups=groups, use_bias=False, name='cv')
        else:
            self.conv = keras.Sequential([
                    TFPad(autopad(kernel_size, None)),
                    keras.layers.Conv2D(filters, kernel_size, strides, padding='valid',
                                        groups=groups, use_bias=False, name='cv')
                ])
        self.bn = keras.layers.BatchNormalization(name='bn') if not deploy else None
        # keras.activations is a module, not a type, so check callability instead of isinstance
        self.act = keras.activations.swish if act is True else (act if callable(act) else keras.activations.linear)

    def call(self, x):
        # NOTE: self.fused_conv (the conv with BatchNorm folded in) is not defined in this snippet
        return self.act(self.fused_conv(x)) if self.deploy else self.act(self.bn(self.conv(x)))

    def get_config(self):
        config = super(Conv, self).get_config()
        config.update({'filters': self.filters, 'kernel_size': self.kernel_size, 'padding':self.padding,
                       'strides': self.strides, 'groups': self.groups})
        return config

With such additions, the Keras code no longer stays backend-agnostic: if I write this code with the TensorFlow backend and then switch to the PyTorch backend, it breaks!

The padding seems different in PyTorch and TensorFlow.

haifeng-jin commented 6 months ago

Hi @abhaskumarsinha,

Thanks for the issue! You are using keras.layers.Conv2D(..., strides=2, padding="same") which does not have a direct equivalent when using torch.nn.Conv2d.

The following snippet shows the equivalent in torch, using torch.nn.functional.pad:

import os

os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
import torch

import keras

keras_layer = keras.layers.Conv2D(64, kernel_size=3, strides=2, padding="same")
keras_layer.build(input_shape=(None, 4, 4, 3))

torch_layer = torch.nn.Conv2d(
    3,
    64,
    kernel_size=3,
    stride=2,
    padding="valid",
)
weight = torch_layer.weight.detach().numpy()
bias = torch_layer.bias.detach().numpy()

keras_layer.set_weights([np.transpose(weight, axes=(2, 3, 1, 0)), bias])

input = np.random.rand(1, 4, 4, 3).astype("float32")
keras_output = keras_layer(input).detach().numpy()
# print(keras_output)

torch_input = torch.from_numpy(np.transpose(input, axes=(0, 3, 1, 2)))
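# Pad one pixel on the right and bottom only (F.pad order: left, right, top, bottom), as Keras 'same' does here
torch_input = torch.nn.functional.pad(torch_input, (0, 1, 0, 1))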
torch_output = np.transpose(
    torch_layer(torch_input).detach().numpy(), axes=(0, 2, 3, 1)
)
# print(torch_output)

print(np.isclose(keras_output, torch_output))
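The `(0, 1, 0, 1)` tuple above is specific to a 4x4 input with kernel 3 and stride 2. A minimal sketch of how it could be derived for other sizes, assuming TF's usual `'same'` rule (output = ceil(n / s), any odd leftover padding goes after); `same_pad_1d` and `keras_same_as_torch_pad` are hypothetical helper names, not Keras or PyTorch APIs.

```python
def same_pad_1d(n, k, s, d=1):
    """(before, after) zeros TF/Keras 'same' adds along one spatial dim."""
    k_eff = d * (k - 1) + 1                 # dilated kernel extent
    out = -(-n // s)                        # ceil(n / s)
    total = max((out - 1) * s + k_eff - n, 0)
    return total // 2, total - total // 2   # odd remainder goes after (bottom/right)

def keras_same_as_torch_pad(h, w, k, s, d=1):
    """Pad argument for torch.nn.functional.pad on an NCHW tensor: (left, right, top, bottom)."""
    top, bottom = same_pad_1d(h, k, s, d)
    left, right = same_pad_1d(w, k, s, d)
    return (left, right, top, bottom)

print(keras_same_as_torch_pad(4, 4, k=3, s=2))      # (0, 1, 0, 1)
print(keras_same_as_torch_pad(128, 128, k=3, s=2))  # (0, 1, 0, 1)
print(keras_same_as_torch_pad(5, 5, k=3, s=2))      # (1, 1, 1, 1)
```

For odd input sizes (and for stride 1 with an odd kernel) the padding comes out symmetric, which is when a plain `padding=1` in `torch.nn.Conv2d` already matches Keras.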

As for the second problem, your code breaks with the torch backend because the torch backend does not support the tf ops (tf.constant, tf.pad) that you use directly in the TFPad layer.

Let me know whether this helps solve the problem.

Thanks!
