keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Layernorm not supporting axis [-2, 3] #19642

Closed lllllllllaa closed 2 weeks ago

lllllllllaa commented 2 weeks ago

Hi, I wanted to normalise my output over the -2 and -3 axes (image height and width). However, with rms_scaling=True, self.gamma is not broadcast to the same shape as the layer input, which causes this error:

inputs shape: (1, 1920, 1200, 3)
inv shape: (1, 1, 1, 3)
gamma_cast shape: (1920, 1200)
(inputs * inv) shape: (1, 1920, 1200, 3)
2024-04-30 13:50:54.238379: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: Incompatible shapes: [1,1920,1200,3] vs. [1920,1200]
Traceback (most recent call last):
  File "C:\Users\88bbh\PycharmProjects\AI\tempt.py", line 10, in <module>
    layer(np.zeros((1, 1920, 1200, 3)))
  File "C:\Users\88bbh\PycharmProjects\AI\venv\lib\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\88bbh\PycharmProjects\AI\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 5983, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Exception encountered when calling LayerNormalization.call().

{{function_node __wrapped__Mul_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [1,1920,1200,3] vs. [1920,1200] [Op:Mul] name: 

Arguments received by LayerNormalization.call():
  • inputs=tf.Tensor(shape=(1, 1920, 1200, 3), dtype=float32)

Code to reproduce:

import numpy as np
import keras

layer = keras.layers.LayerNormalization(axis=[-3, -2], rms_scaling=True)
layer.build([None, 1920, 1200, 3])
layer(np.zeros((1, 1920, 1200, 3)))
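For context, the computation the layer should perform can be sketched in plain NumPy (a minimal sketch with made-up small shapes; rms_norm and its reshaping are illustrative, not the actual Keras implementation):

```python
import numpy as np

def rms_norm(inputs, gamma, axes=(-3, -2), epsilon=1e-3):
    # Variance over the normalized axes, kept so it broadcasts: (1, 1, 1, C)
    variance = inputs.var(axis=axes, keepdims=True)
    inv = 1.0 / np.sqrt(variance + epsilon)
    # gamma has shape (H, W); reshape to (1, H, W, 1) so it broadcasts
    # against the (1, H, W, C) input instead of raising a shape error.
    gamma_b = gamma.reshape((1,) + gamma.shape + (1,))
    return inputs * inv * gamma_b

x = np.zeros((1, 8, 6, 3), dtype="float32")
gamma = np.ones((8, 6), dtype="float32")
out = rms_norm(x, gamma)
print(out.shape)  # (1, 8, 6, 3)
```

Without the reshape, multiplying a (1, H, W, C) tensor by an (H, W) gamma is exactly the incompatible-shapes error shown above.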

The error is in the LayerNormalization call() method:

        if self.rms_scaling:
            # Calculate outputs with only variance and gamma if rms scaling
            # is enabled
            # Calculate the variance along self.axis (layer activations).
            variance = ops.var(inputs, axis=self.axis, keepdims=True)
            inv = ops.rsqrt(variance + self.epsilon)
            print("inputs shape:", inputs.shape)
            print("inv shape:", inv.shape)
            print("gamma_cast shape:", self.gamma.shape)
            print("(inputs * inv) shape:", (inputs * inv).shape)
            outputs = inputs * inv * ops.cast(self.gamma, inputs.dtype)

The error can be fixed by changing

outputs = inputs * inv * ops.cast(self.gamma, inputs.dtype)
to
outputs = inputs * inv * ops.cast(_broadcast(self.gamma), inputs.dtype)
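What the broadcast step has to do can be sketched as follows (a hypothetical standalone helper; the name broadcast_to_input and its signature are illustrative, not the Keras-internal _broadcast): place gamma's dimensions on the normalized axes and size-1 dimensions everywhere else, so elementwise multiplication broadcasts against the full input.

```python
import numpy as np

def broadcast_to_input(gamma, input_rank, axes):
    # Map possibly-negative axes to positive positions, e.g. [-3, -2] -> [1, 2].
    norm_axes = sorted(ax % input_rank for ax in axes)
    # Start from all-ones shape, then place gamma's dims on the normalized axes.
    shape = [1] * input_rank
    for dim, ax in zip(gamma.shape, norm_axes):
        shape[ax] = dim
    return gamma.reshape(shape)

gamma = np.ones((1920, 1200), dtype="float32")
g = broadcast_to_input(gamma, input_rank=4, axes=[-3, -2])
print(g.shape)  # (1, 1920, 1200, 1)
```

A (1, 1920, 1200, 1) gamma then multiplies cleanly against the (1, 1920, 1200, 3) input, which is why routing self.gamma through the broadcast helper resolves the error.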

Please fix it in the next update, thank you.

SuryanarayanaY commented 2 weeks ago

Hi @lllllllllaa ,

Thanks for reporting. I can confirm the issue, and the proposed fix looks correct. The fix is proposed in the attached PR.

Thanks!
