keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.

Casting dtype in losses.Loss base class #922

Closed PatReis closed 10 months ago

PatReis commented 10 months ago

I have noticed that the Loss base class casts both of its inputs to backend.floatx(), regardless of the input dtype.
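A minimal reproduction sketch (assuming the default floatx() of "float32"; the exact dtype representation printed may vary by backend):

```python
import numpy as np
import keras_core as keras

y_true = np.array([1.0, 2.0], dtype="float64")
y_pred = np.array([1.1, 1.9], dtype="float64")

# Both inputs are float64, but the base class casts them to
# backend.floatx() ("float32" by default) before computing.
loss = keras.losses.MeanSquaredError()
print(loss(y_true, y_pred).dtype)  # float32
```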

In some special cases, however, most of the model should run in lower precision (backend.floatx()), while some parts, and especially the loss, require higher precision such as "float64". In that situation Loss downcasts to "float32", which could hinder tight convergence. The tf.keras behaviour, if I am not mistaken, is different and tries to maintain the highest-precision dtype in the loss.

My question: is there a specific reason, e.g. speed, to strictly keep everything at floatx() in the loss, or could the loss also run in a higher or different precision?

Of course it is easy to write one's own loss function, or even to copy-paste the source code and change it. Nonetheless, I would like to ask whether you plan a more flexible cast in Loss, or simply a dtype=None parameter as in metrics. If you do not mind, I could try a pull request adding an optional dtype parameter to the base class and the loss wrapper only.
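For illustration, a standalone workaround sketch: a plain function (the helper name mse_float64 is hypothetical) called directly, outside the Loss base class, so no floatx() cast is applied:

```python
from keras_core import ops


def mse_float64(y_true, y_pred):
    # Hypothetical standalone loss: upcast and reduce in float64.
    y_true = ops.cast(y_true, "float64")
    y_pred = ops.cast(y_pred, "float64")
    return ops.mean(ops.square(y_pred - y_true))
```

Note that passing such a function to model.compile() would still route it through the loss wrapper and its cast, so this only helps in a custom training loop. It also does not work on the JAX backend, where float64 is unavailable by default (see below).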

fchollet commented 10 months ago

Thanks for the report. This looks like an oversight. The expectation is that Loss should have a configurable dtype argument, defaulting to floatx(). We'll fix it.
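If it lands as described, usage would presumably look like this (a sketch of the expected API, not confirmed syntax):

```python
import keras_core as keras

# Hypothetical once the fix is in: compute this loss in float64
# regardless of the global floatx() setting.
loss = keras.losses.MeanSquaredError(dtype="float64")
```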

fchollet commented 10 months ago

This is fixed at HEAD.

Do note that when using the JAX backend, 64-bit dtypes are disabled by default: jnp.array(..., dtype="float64") returns a float32 array.
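For context (a JAX detail, not specific to keras-core): JAX can be opted into 64-bit dtypes via its x64 flag, ideally set at program startup before any arrays are created:

```python
import jax
import jax.numpy as jnp

print(jnp.array(1.0, dtype="float64").dtype)  # float32: x64 is off by default

# Opt in to 64-bit types.
jax.config.update("jax_enable_x64", True)
print(jnp.array(1.0, dtype="float64").dtype)  # float64
```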