albertaillet opened this issue 2 weeks ago.
Just like in layers, you're supposed to override `call()`, not `__call__()`. You could override `__call__()`, but by doing so you miss a bit of built-in functionality, including auto-broadcasting. So you can just override `call()` and call `self.loss(y_true, y_pred)` there.
I agree; however, in the example, I think the one that isn't working is the one that overrides `call()` in the recommended way, if I am not mistaken.
Hi @albertaillet,
I have reproduced the reported error when overriding the `call` method. Attached gist for reference.
Problem
When wrapping the `binary_crossentropy` loss function in another `keras.losses.Loss`, it no longer supports targets with a flat shape and requires a shape of the form `(..., 1)`. This does not happen when it is simply wrapped in a function or in a class with a `__call__()` method.
How to reproduce
The following script can be used to reproduce this error.
The error is the following:
Is there a recommended way?
In case this is expected behaviour, what is the recommended way to wrap a loss function as a `keras.losses.Loss` class and handle both flat and `(..., 1)` target shapes?
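For illustration only, here is a minimal NumPy sketch of the shape issue described in the Problem section. It is not the reproduction script from the issue and it does not use Keras. An elementwise crossentropy computed on a flat target of shape `(batch,)` and a prediction of shape `(batch, 1)` silently broadcasts to a `(batch, batch)` matrix. The hypothetical helpers `bce_elementwise`, `align_target_rank` and `wrapped_loss` are assumptions made for this sketch: the target is expanded or squeezed to the prediction's rank before the loss is computed, roughly the kind of rank alignment Keras's built-in loss wrappers are understood to perform. Inside a `keras.losses.Loss.call()` the same idea could presumably be written with `keras.ops.expand_dims` / `keras.ops.squeeze`, but that translation is an assumption, not something confirmed in this thread.

```python
import numpy as np


def bce_elementwise(y_true, y_pred, eps=1e-7):
    """Elementwise binary crossentropy; relies on NumPy broadcasting."""
    # Clip predictions away from 0 and 1 so the logs stay finite.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))


def align_target_rank(y_true, y_pred):
    """Hypothetical helper: give y_true the same rank as y_pred.

    A flat target (batch,) is expanded to (batch, 1) when the prediction is
    (batch, 1); a target with an extra trailing axis of size 1 is squeezed.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.ndim == y_pred.ndim - 1 and y_pred.shape[-1] == 1:
        y_true = np.expand_dims(y_true, axis=-1)
    elif y_true.ndim == y_pred.ndim + 1 and y_true.shape[-1] == 1:
        y_true = np.squeeze(y_true, axis=-1)
    return y_true, y_pred


def wrapped_loss(y_true, y_pred):
    """Per-sample loss: align ranks first, then average over the last axis."""
    y_true, y_pred = align_target_rank(y_true, y_pred)
    return bce_elementwise(y_true, y_pred).mean(axis=-1)


y_flat = np.array([0.0, 1.0, 1.0])     # flat targets, shape (3,)
y_col = y_flat[:, None]                # (..., 1) targets, shape (3, 1)
y_pred = np.array([[0.1], [0.8], [0.6]])

# Without alignment, (3,) and (3, 1) broadcast to a (3, 3) matrix.
print(bce_elementwise(y_flat, y_pred).shape)  # (3, 3)

# With alignment, flat and (..., 1) targets give the same per-sample loss.
print(np.allclose(wrapped_loss(y_flat, y_pred), wrapped_loss(y_col, y_pred)))  # True
```

The design choice here is to normalize shapes before the elementwise computation rather than after, since once broadcasting has produced a `(batch, batch)` matrix the per-sample values can no longer be recovered.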