keras-team / keras

Deep Learning for humans
http://keras.io/

Shape error for some use cases of `binary_crossentropy`. #19662

Open albertaillet opened 2 weeks ago

albertaillet commented 2 weeks ago

Problem

When the binary_crossentropy loss function is wrapped in another keras.losses.Loss, it no longer supports targets with a flat shape and instead requires a shape of the form (..., 1). This does not happen when it is simply wrapped in a plain function or in a class with a __call__() method.

How to reproduce

The following script can be used to reproduce this error.

import keras

def fit_model_with_loss(loss):
    model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])
    model.compile(optimizer="sgd", loss=loss, metrics=["accuracy"])
    model.fit(x, y, batch_size=16, epochs=2)

x = keras.random.uniform((32, 1))
y = keras.random.randint((32,), 0, 2)  # flat integer targets of shape (32,); maxval is exclusive

loss = keras.losses.get("binary_crossentropy")
fit_model_with_loss(loss)  # works fine

def loss_wrapped_with_function(*args, **kwargs):
    return loss(*args, **kwargs)

fit_model_with_loss(loss_wrapped_with_function)  # works fine

class LossWrapper:
    def __init__(self, loss) -> None:
        super().__init__()
        self.loss = loss

    def __call__(self, *args, **kwargs):
        return self.loss(*args, **kwargs)

fit_model_with_loss(LossWrapper(loss))  # works fine

class LossWrapperInherit(keras.losses.Loss):
    def __init__(self, loss) -> None:
        super().__init__()
        self.loss = loss

    def call(self, *args, **kwargs):
        return self.loss(*args, **kwargs)

fit_model_with_loss(LossWrapperInherit(loss))  # gets a shape error

The error is the following:

File "/keras/keras/src/losses/loss.py", line 43, in __call__
    losses = self.call(y_true, y_pred)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/keras/reproduce_keras_error.py", line 45, in call
    return self.loss(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/keras/keras/src/losses/losses.py", line 1782, in binary_crossentropy
    ops.binary_crossentropy(y_true, y_pred, from_logits=from_logits),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/keras/keras/src/ops/nn.py", line 1398, in binary_crossentropy
    return backend.nn.binary_crossentropy(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/keras/keras/src/backend/jax/nn.py", line 518, in binary_crossentropy
    raise ValueError(
ValueError: Arguments `target` and `output` must have the same shape. Received: target.shape=(16,), output.shape=(16, 1)

Is there a recommended way?

In case this is expected behaviour, what is the recommended way to wrap a loss function as a keras.losses.Loss class so that it handles both flat and (..., 1) target shapes?
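
One possible workaround (a rough sketch of my own, reusing fit_model_with_loss and loss from the script above; the class name and the rank check are not from any Keras recommendation) would be to align the target's rank with the prediction's inside call() before delegating:

from keras import ops

class LossWrapperAligned(keras.losses.Loss):
    def __init__(self, loss) -> None:
        super().__init__()
        self.loss = loss

    def call(self, y_true, y_pred):
        # Assumption: a flat target only differs from the (..., 1) model output
        # by a trailing axis of size 1, so adding that axis aligns the shapes.
        if len(y_true.shape) == len(y_pred.shape) - 1:
            y_true = ops.expand_dims(y_true, axis=-1)
        return self.loss(y_true, y_pred)

fit_model_with_loss(LossWrapperAligned(loss))  # sketch: intended to avoid the shape error

I am not sure whether this is the intended pattern, though.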

fchollet commented 2 weeks ago

Just like in layers, you're supposed to override call(), not __call__(). You could override __call__(), but by doing so you miss a bit of built-in functionality, including auto-broadcasting. So you can just override call() and call self.loss(y_true, y_pred) there.
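
For reference, a minimal sketch of what that built-in functionality covers (the class and values below are only illustrative): the base __call__ applies sample weighting and the configured reduction around whatever call() returns.

import numpy as np
import keras

class AbsError(keras.losses.Loss):
    def call(self, y_true, y_pred):
        # Return per-sample values; the base __call__ handles the rest.
        return keras.ops.abs(y_true - y_pred)

loss_fn = AbsError()
y_true = np.array([1.0, 0.0, 1.0])
y_pred = np.array([0.5, 0.5, 0.5])
# sample_weight is applied to the per-sample losses and the result is
# reduced to a scalar; an overridden __call__ would have to redo all of this.
print(loss_fn(y_true, y_pred, sample_weight=np.array([1.0, 0.0, 2.0])))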

albertaillet commented 2 weeks ago

I agree; however, in the example, I believe the one that is not working is the one that overrides call() in the recommended way.

SuryanarayanaY commented 1 week ago

> I agree; however, in the example, I believe the one that is not working is the one that overrides call() in the recommended way.

Hi @albertaillet ,

I have reproduced the reported error when overriding the call() method. Attached gist for reference.