keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Predict fails on a model containing the Attention layer. #20429

Open HGS-mbayer opened 6 days ago

HGS-mbayer commented 6 days ago

Running predict on a model containing an Attention layer raises a RuntimeError on the torch backend: the layer receives 4-D inputs, but the permute inside Attention.call() expects 3-D tensors.

Example Code

Here is a dummy model to reproduce the issue.

import keras
import numpy as np

INPUT_SHAPE = (128, 128, 3)
NUM_CLASSES = 3

def create_model(dims: tuple[int, int, int], num_classes: int):
    width, height, bands = dims

    inputs = keras.layers.Input((width, height, bands))

    conv1 = keras.layers.Conv2D(8, (3, 3), padding='same')(inputs)
    bn = keras.layers.BatchNormalization()(conv1)
    act = keras.layers.Activation('relu')(bn)
    pool1 = keras.layers.MaxPooling2D((2, 3), strides=(2, 2))(act)
    attention = keras.layers.Attention(use_scale=False, score_mode='dot')(
        [pool1, pool1]
    )
    output = keras.layers.Conv2D(
        num_classes, (1, 1), padding='same', activation='softmax'
    )(attention)

    model = keras.models.Model(inputs=inputs, outputs=output)

    return model

model = create_model(INPUT_SHAPE, NUM_CLASSES)

data = np.random.rand(1, *INPUT_SHAPE)

output = model.predict(data)
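
As a sanity check on the shapes involved, here is a minimal sketch of the arithmetic behind the tensor that reaches the Attention layer. It assumes the Keras defaults apply (MaxPooling2D uses 'valid' padding, and the 'same'-padded Conv2D with stride 1 keeps the 128 x 128 spatial size); `pooled_size` is a hypothetical helper written for this illustration, not a Keras function.

```python
def pooled_size(size: int, pool: int, stride: int) -> int:
    """Output length along one axis for pooling with 'valid' padding."""
    return (size - pool) // stride + 1


# The Conv2D layer keeps the spatial size at 128 x 128 and produces 8 channels.
# MaxPooling2D((2, 3), strides=(2, 2)) then shrinks each spatial axis.
height = pooled_size(128, pool=2, stride=2)
width = pooled_size(128, pool=3, stride=2)

# Batch of 1, channels last, as in the reproduction script.
attention_input_shape = (1, height, width, 8)
print(attention_input_shape)
```

The result is a 4-D shape (batch, height, width, channels), whereas the Keras documentation describes the Attention inputs as 3-D tensors of shape (batch_size, Tq, dim) and (batch_size, Tv, dim).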

Training also appears to fail:

model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.MeanSquaredError())
model.fit(data, data)

Traceback

RuntimeError                              Traceback (most recent call last)
Cell In[19], line 33
     29 model = create_model(INPUT_SHAPE, NUM_CLASSES)
     31 data = np.random.rand(1, *INPUT_SHAPE)
---> 33 output = model.predict(data)

File ...\env\lib\site-packages\keras\src\utils\traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ...\env\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ...\env\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ...\env\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ...\env\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

RuntimeError: Exception encountered when calling Attention.call().

permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 4 is not equal to len(dims) = 3

Arguments received by Attention.call():
  • inputs=['torch.Tensor(shape=torch.Size([1, 64, 63, 8]), dtype=float32)', 'torch.Tensor(shape=torch.Size([1, 64, 63, 8]), dtype=float32)']
  • mask=['None', 'None']
  • training=False
  • return_attention_scores=False
  • use_causal_mask=False
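
The error message suggests that the layer swaps the last two axes of the key with a three-axis ordering, which cannot work on a 4-D tensor. The sketch below is a NumPy illustration of dot-product attention on 3-D inputs, not the Keras implementation; `dot_attention` is a hypothetical helper. It assumes the intent is attention across spatial positions, so the height and width axes are flattened into one sequence axis first and restored afterwards.

```python
import numpy as np


def dot_attention(query: np.ndarray, value: np.ndarray) -> np.ndarray:
    """Dot-product attention for 3-D inputs, using the value as the key.

    query: (batch, Tq, dim), value: (batch, Tv, dim) -> (batch, Tq, dim)
    """
    # Swap the last two axes of the key. This ordering has three entries,
    # so it only applies to 3-D tensors.
    scores = query @ np.transpose(value, (0, 2, 1))  # (batch, Tq, Tv)
    # Subtract the row maximum before exponentiating for numerical stability.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over Tv
    return weights @ value


# Same shape as the tensors listed in the error message above.
rng = np.random.default_rng(0)
feature_map = rng.random((1, 64, 63, 8), dtype=np.float32)

# Flatten (height, width) into a single sequence axis, attend, then restore.
batch, height, width, channels = feature_map.shape
sequence = feature_map.reshape(batch, height * width, channels)
attended = dot_attention(sequence, sequence)
attended = attended.reshape(batch, height, width, channels)
print(attended.shape)
```

In the model above, the analogous change would be a keras.layers.Reshape((64 * 63, 8)) before the Attention layer and a keras.layers.Reshape((64, 63, 8)) after it. That variant has not been tested here, and whether the layer is meant to accept 4-D inputs directly is a question for the maintainers.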

NischitKumar commented 6 days ago

@sachinprasadhs I'm new to open-source contributing. I have some experience with TensorFlow and Keras and a decent machine learning background. I would love to contribute to the repo and learn in the process. Your input would be really helpful! Best regards, Nischit