keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

layers.GRU returns wrong shaped output with GPU #20173

Open Jonii opened 2 weeks ago

Jonii commented 2 weeks ago

I opened this in the TensorFlow repo and was told to move it here: https://github.com/tensorflow/tensorflow/issues/74475

In short, layers.GRU, at least on Google Colab (Keras 3.4.1), returns wrongly shaped outputs when run with a GPU available.

Minimal way to reproduce here:

import tensorflow as tf

class TestModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.gru = tf.keras.layers.GRU(10, return_sequences=True, return_state=True)

    def call(self, inputs):
        return self.gru(inputs)

# Create and test the model
model = TestModel()
test_input = tf.random.uniform((2, 3, 5))  # Batch size = 2, sequence length = 3, feature size = 5
output = model(test_input)
print("Output types and shapes:", [(type(o), o.shape) for o in output])

This prints

With GPU:

Output types and shapes: [(<class 'tensorflow.python.framework.ops.EagerTensor'>, TensorShape([2, 3, 10])), (<class 'tensorflow.python.framework.ops.EagerTensor'>, TensorShape([10])), (<class 'tensorflow.python.framework.ops.EagerTensor'>, TensorShape([10]))]

With CPU:

Output types and shapes: [(<class 'tensorflow.python.framework.ops.EagerTensor'>, TensorShape([2, 3, 10])), (<class 'tensorflow.python.framework.ops.EagerTensor'>, TensorShape([2, 10]))]

CPU behavior seems correct.
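To make the difference explicit, here is a small shape check. It is only a sketch: the helper name check_gru_output_shapes is hypothetical and not part of Keras, and it assumes that GRU(units, return_sequences=True, return_state=True) should return exactly two tensors, shaped (batch, timesteps, units) and (batch, units), as in the CPU output above.

```python
def check_gru_output_shapes(shapes, batch, timesteps, units):
    """Return a list of problems; an empty list means the shapes look right.

    Hypothetical helper, not a Keras API. Assumes the layer was built with
    return_sequences=True and return_state=True.
    """
    expected = [(batch, timesteps, units), (batch, units)]
    shapes = [tuple(s) for s in shapes]
    problems = []
    if len(shapes) != len(expected):
        problems.append(f"expected {len(expected)} outputs, got {len(shapes)}")
    # Compare position by position for as many outputs as both lists share.
    for i, (got, want) in enumerate(zip(shapes, expected)):
        if got != want:
            problems.append(f"output {i}: expected {want}, got {got}")
    return problems


# Shapes copied from the two printouts above.
gpu_shapes = [(2, 3, 10), (10,), (10,)]
cpu_shapes = [(2, 3, 10), (2, 10)]

gpu_problems = check_gru_output_shapes(gpu_shapes, batch=2, timesteps=3, units=10)
cpu_problems = check_gru_output_shapes(cpu_shapes, batch=2, timesteps=3, units=10)
print("GPU:", gpu_problems)
print("CPU:", cpu_problems)
```

With the reproduction script, the shapes can be collected with [tuple(o.shape) for o in output] and passed in directly.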

Edited to add: I cannot test GPU behavior outside of Google Colab, so this might be a bug that has already been fixed in the latest version, or it might be due to a Colab-specific misconfiguration.

sachinprasadhs commented 2 weeks ago

I was able to reproduce the reported behavior, attaching the Gist here

With the Torch backend, it produces the expected outcome, as below:

Output types and shapes: [(<class 'torch.Tensor'>, torch.Size([2, 3, 10])), (<class 'torch.Tensor'>, torch.Size([2, 10]))]

mattdangerw commented 2 weeks ago

Probably an issue with the cuDNN-specific implementation on the TF backend, which is pretty dense. I will take a look.

AdityaMayukhSom commented 2 weeks ago

A similar issue occurs when running Keras with the TensorFlow backend on a desktop. The hidden states of the individual elements in a batch are returned as separate entries in the GRU's output tuple, not as a single tensor whose first dimension equals the batch size.
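Until this is fixed, one possible stopgap is to restack the per-sample states into a single (batch, units) array. The sketch below is untested against the actual bug: restack_states is a hypothetical helper, and it assumes the extra tuple entries are the per-sample final states in batch order, which matches the description above but has not been confirmed. NumPy arrays stand in for the real tensors.

```python
import numpy as np


def restack_states(outputs, batch_size):
    """Return (sequences, state) with state shaped (batch, units).

    Hypothetical workaround, not a Keras API. Assumes that any extra
    entries after the first output are per-sample states in batch order.
    """
    outputs = list(outputs)
    sequences, rest = outputs[0], outputs[1:]
    # Already correct: a single 2-D state tensor.
    if len(rest) == 1 and np.ndim(rest[0]) == 2:
        return sequences, np.asarray(rest[0])
    if len(rest) != batch_size:
        raise ValueError(
            f"expected {batch_size} per-sample states, got {len(rest)}"
        )
    # Stack the 1-D per-sample states along a new leading batch axis.
    state = np.stack([np.asarray(s) for s in rest], axis=0)
    return sequences, state


# Stand-in data with the shapes from the GPU printout: (2, 3, 10), (10,), (10,).
seq = np.zeros((2, 3, 10))
broken = [seq, np.zeros(10), np.ones(10)]
_, fixed_state = restack_states(broken, batch_size=2)
print(fixed_state.shape)
```

Outputs that already have the correct two-element form pass through with the state unchanged in shape, so the helper can wrap the layer call on both CPU and GPU.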