keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

mul got incompatible shapes for broadcasting #20219

Closed: Isa-Fay closed this issue 3 weeks ago

Isa-Fay commented 2 months ago

I'm running the code below with the JAX backend. Static shape inference (compute_output_shape) succeeds, but calling the layer eagerly on concrete data fails with an error message I found confusing. Here is the code and the results I got.

Versions: Python 3.10, Keras 3.5.0, JAX 0.4.31

import os

# The backend must be selected before Keras is imported.
os.environ['KERAS_BACKEND'] = 'jax'

import numpy as np
import keras

layer = keras.layers.BatchNormalization(
    axis=-1,
    momentum=0.99,
    epsilon=0.001,
    center=False,
    scale=False,
    beta_initializer="zeros",
    gamma_initializer="ones",
    moving_mean_initializer="zeros",
    moving_variance_initializer="ones",
    synchronized=False,
    trainable=True,
    autocast=True,
)

# Static shape inference: only the input shape is considered, no mask.
result_static = layer.compute_output_shape([2, 3])

# Eager call on concrete data. Note that the mask has shape (4,),
# which does not match the (2, 3) input.
result_dynamic = layer(
    inputs=np.random.rand(2, 3),
    training=True,
    mask=np.random.rand(4),
)

The static result: 5

The dynamic result: 6

sanskarmodi8 commented 2 months ago

Hi @Isa-Fay, I looked into this and found that the mask argument has shape (4,), which is incompatible with the input shape (2, 3). The mask should typically have the same shape as the input tensor, or a shape that broadcasts to it. You can use mask=np.random.rand(2, 3) instead.
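A minimal NumPy sketch of why the shapes clash. It assumes (this is an assumption about the Keras 3.5 internals, not something stated in the thread) that BatchNormalization computes masked moments by expanding the mask along the last axis and multiplying it with the inputs. The helper name try_masked_weighting is hypothetical. Under that assumption, a (4,) mask fails, a mask with the same shape as the input, (2, 3), also fails, and only a mask with one dimension fewer than the input, (2,), broadcasts:

```python
import numpy as np

inputs = np.random.rand(2, 3)


def try_masked_weighting(mask_shape):
    """Hypothetical helper mimicking the ASSUMED masked-moment step:
    expand the mask along the last axis, then multiply by the inputs."""
    mask = np.ones(mask_shape, dtype=inputs.dtype)
    mask_broadcasted = np.expand_dims(mask, axis=-1)  # e.g. (4,) -> (4, 1)
    try:
        return (mask_broadcasted * inputs).shape
    except ValueError as err:  # NumPy's equivalent of JAX's broadcasting error
        return f"error: {err}"


for shape in [(4,), (2, 3), (2,)]:
    print(shape, "->", try_masked_weighting(shape))
```

NumPy raises ValueError where JAX reports "mul got incompatible shapes for broadcasting"; both follow the same broadcasting rules.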

Isa-Fay commented 2 months ago

Thanks for your reply! The input is indeed wrong, but I think the inconsistency between the static and dynamic results deserves attention, since it can confuse users and make problems hard to debug. Ideally, Keras would validate its inputs instead of passing invalid arguments straight to the backend.

sanskarmodi8 commented 2 months ago

Thank you for your feedback!

I agree that the inconsistency between the static and dynamic outputs can be confusing for users. Adding input checks at the Keras layer level would definitely help prevent these kinds of errors by validating shapes and dimensions before passing them to the backend.
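A sketch of the kind of layer-level check discussed here, written as a standalone function rather than as Keras code. The name validate_batchnorm_mask and the rule it enforces (mask shape equals inputs.shape[:-1], i.e. normalization over the last axis) are hypothetical assumptions for illustration; this is not the check that was merged into Keras:

```python
import numpy as np


def validate_batchnorm_mask(inputs, mask):
    """Hypothetical pre-backend check: with axis=-1, require the mask to
    have the input's shape minus the last (feature) axis."""
    if mask is None:
        return
    expected = tuple(inputs.shape[:-1])
    actual = tuple(np.shape(mask))
    if actual != expected:
        raise ValueError(
            f"Expected mask of shape {expected} for inputs of shape "
            f"{tuple(inputs.shape)}, got {actual}."
        )


inputs = np.random.rand(2, 3)
validate_batchnorm_mask(inputs, np.ones(2))  # passes silently
try:
    validate_batchnorm_mask(inputs, np.random.rand(4))
except ValueError as err:
    print(err)  # Expected mask of shape (2,) for inputs of shape (2, 3), got (4,).
```

Raising a ValueError with both shapes in the message surfaces the problem at the layer call, before any backend operation runs.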

I have opened Pull Request #20237 for the similar issue #20221. Once I get a response on it, I will add the same validation checks to all similar normalization layers.

sachinprasadhs commented 1 month ago

This now works in keras-nightly with the fix from the PR linked above; the expected outcome is attached in the Gist here.

github-actions[bot] commented 1 month ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot] commented 3 weeks ago

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.
