Closed: PatReis closed this issue 10 months ago
Thanks for the report. This looks like an oversight. The expectation is that `Loss` should have a configurable `dtype` argument, defaulting to `floatx()`. We'll fix it.
This is fixed at HEAD.
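For reference, the intended usage would presumably look like the sketch below (the exact argument wiring may differ from what ships):

```python
import keras

# Presumed usage once the fix is in: losses accept a dtype argument,
# defaulting to keras.backend.floatx() when not given.
loss_fn = keras.losses.MeanSquaredError(dtype="float64")
```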
Do note that when using the JAX backend, 64-bit dtypes are disabled by default: `jnp.array(..., dtype="float64")` returns a float32 array.
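A small sketch of that JAX behavior (the `jax_enable_x64` flag is standard JAX configuration, unrelated to this fix):

```python
import jax
import jax.numpy as jnp

# By default JAX disables 64-bit types: float64 silently becomes float32.
x = jnp.array([1.0], dtype="float64")
print(x.dtype)  # float32

# 64-bit support can be enabled explicitly (ideally at program startup,
# before any arrays are created).
jax.config.update("jax_enable_x64", True)
y = jnp.array([1.0], dtype="float64")
print(y.dtype)  # float64
```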
I have noticed that the `Loss` base class casts both inputs to the loss to `backend.floatx()` regardless of the input dtype. In some special cases, however, most of the model should run in lower precision (with `backend.floatx()`), but some parts, and especially the loss, require higher precision such as "float64". In that case `Loss` downcasts to "float32", which could hinder tight convergence. The `tf.keras` behaviour, if I am not mistaken, is different and tries to maintain the highest-precision loss.

My question: is there a specific reason, regarding speed performance, to strictly keep everything at floatx in the loss, or can the loss also have a higher/different precision?
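A minimal reproduction sketch of what I mean, assuming the TensorFlow backend (before the fix, the result comes back as float32 even for float64 inputs):

```python
import numpy as np
import keras

# Pass float64 targets/predictions into a built-in loss and
# inspect the dtype of the result.
y_true = np.array([[0.0, 1.0]], dtype="float64")
y_pred = np.array([[0.1, 0.9]], dtype="float64")

loss = keras.losses.MeanSquaredError()
value = loss(y_true, y_pred)
print(value.dtype)  # float32: inputs are cast to backend.floatx()
```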
Of course it is easy to write one's own loss function, or even copy-paste the source code and change it. Nonetheless, I would like to inquire whether you plan a flexible cast in `Loss`, or a simple `dtype=None` parameter as in metrics. If you do not mind, I can try a pull request with an optional dtype parameter for the base class and the loss wrapper only.
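As a stopgap, one hypothetical workaround I have in mind is a subclass that bypasses the base class's cast by overriding `__call__` (casting inside `call()` alone would be too late, since the inputs are already downcast before `call()` is invoked). A rough sketch, not a tested implementation:

```python
import keras
from keras import ops


class Float64MSE(keras.losses.Loss):
    """Hypothetical workaround: keep the loss computation in float64
    instead of relying on the base class's floatx() cast."""

    def __call__(self, y_true, y_pred, sample_weight=None):
        # Cast to float64 ourselves, before any floatx() downcast can occur.
        y_true = ops.cast(ops.convert_to_tensor(y_true), "float64")
        y_pred = ops.cast(ops.convert_to_tensor(y_pred), "float64")
        # Reduce manually; the base class's reduction machinery is skipped here.
        return ops.mean(ops.square(y_pred - y_true))
```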