tensorflow / ranking

Learning to Rank in TensorFlow
Apache License 2.0

Clarifying global floating point policy #357

Open Sinacam opened 11 months ago

Sinacam commented 11 months ago

Hardcoded float precision affects many computations in tensorflow_ranking, for example https://github.com/tensorflow/ranking/blob/a928e2b1930a1ebcae2c509e3f6ca95941fd1e49/tensorflow_ranking/python/metrics_impl.py#L603-L628

This has been mentioned before in #254, but I want to elaborate on our difficulties. Hardcoded dtypes like these make it extremely hard to move our programs to float64. For example, if we call tf.keras.backend.set_floatx('float64') anywhere, we get errors inside tensorflow_ranking due to conflicting dtypes.
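To illustrate the kind of conflict, here is a minimal sketch. It does not use tensorflow_ranking's actual code: `discount_hardcoded` and `discount_floatx` are hypothetical helpers that only mimic a rank discount with a hardcoded tf.float32 cast versus one that reads tf.keras.backend.floatx(). It assumes default TF2 eager behavior, where mixing float32 and float64 tensors raises an error.

```python
import tensorflow as tf


def discount_hardcoded(ranks):
    # Hypothetical helper: always casts to float32, regardless of floatx.
    return 1.0 / tf.math.log1p(tf.cast(ranks, tf.float32))


def discount_floatx(ranks):
    # Hypothetical helper: follows the global Keras float type instead.
    dtype = tf.keras.backend.floatx()
    return 1.0 / tf.math.log1p(tf.cast(ranks, dtype))


tf.keras.backend.set_floatx("float64")

# User-side tensors now follow floatx, i.e. float64.
gains = tf.constant([3.0, 1.0, 2.0], dtype=tf.keras.backend.floatx())
ranks = tf.constant([1, 2, 3])

try:
    # float64 gains multiplied by a float32 discount: dtypes conflict.
    gains * discount_hardcoded(ranks)
    conflict = False
except (TypeError, ValueError, tf.errors.InvalidArgumentError):
    conflict = True

# With the floatx-aware helper, the whole computation stays in float64.
dcg = tf.reduce_sum(gains * discount_floatx(ranks))
print("conflict with hardcoded dtype:", conflict)
print("dtype with floatx-aware helper:", dcg.dtype)
```

Reading the dtype from floatx() at call time is only one possible design; deriving it from the input tensors' dtype would be another way to avoid the mismatch.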

Will the global floating point policy (tf.keras.mixed_precision.set_global_policy and tf.keras.backend.floatx) be supported? If the official stance on the global policy is to ignore it, can it be documented?