StoneT2000 closed this issue 1 year ago.
This looks like a floating-point (float32) precision issue that is independent of distrax and is also platform-dependent.
For example, I just ran the following code on Google Colab:
```python
import jax.numpy as jnp

a = jnp.array([-8.51089])
print(jnp.tanh(a))
```
It prints `Array([-1.0000001], dtype=float32)` when running on a GPU, but `Array([-1.], dtype=float32)` when running on a CPU.
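To see why the GPU result matters, here is a small sketch (my own illustration, not from the thread) of what happens when that output is fed back through the inverse of tanh. It uses `jnp.nextafter` to construct the float32 value immediately below -1.0, which is what the GPU run printed as `-1.0000001`. `arctanh` is only defined on [-1, 1], so that value yields NaN, while exactly -1.0 yields -inf.

```python
import jax.numpy as jnp

# The float32 value immediately below -1.0 (printed as -1.0000001), i.e. what tanh returned on GPU.
y_over = jnp.nextafter(jnp.float32(-1.0), jnp.float32(-2.0))
# Exactly -1.0, i.e. what tanh returned on CPU.
y_exact = jnp.float32(-1.0)

# arctanh (the inverse of tanh) is only defined on [-1, 1]:
# a value just outside the domain gives NaN, and the boundary itself gives -inf.
print(y_over)               # -1.0000001
print(jnp.arctanh(y_over))  # nan
print(jnp.arctanh(y_exact)) # -inf
```

Either outcome then propagates into anything computed from the inverted value, such as a log-probability.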
To avoid such numerical issues, we strongly recommend clipping your inputs or outputs so that they stay within a reasonable range.
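A minimal sketch of the output-clipping approach, assuming the goal is to keep tanh values strictly inside (-1, 1) both after the forward pass and before inverting. The helper names `clipped_tanh` and `safe_arctanh` are hypothetical, and `eps=1e-6` is an arbitrary choice. It only needs to be larger than the float32 spacing just below 1.0 (about 6e-8) so that `1 - eps` is a distinct float32 value.

```python
import jax.numpy as jnp

def clipped_tanh(x, eps=1e-6):
    # Hypothetical helper: clip the forward output so it never lands on or past +/-1,
    # regardless of how the backend rounds tanh for large |x|.
    return jnp.clip(jnp.tanh(x), -1.0 + eps, 1.0 - eps)

def safe_arctanh(y, eps=1e-6):
    # Hypothetical helper: pull y back into the open interval (-1, 1) before inverting,
    # so arctanh never sees -1.0000001 (NaN) or exactly +/-1 (+/-inf).
    y = jnp.clip(y, -1.0 + eps, 1.0 - eps)
    return jnp.arctanh(y)

y = jnp.array([-1.0000001, -1.0, 0.5, 1.0], dtype=jnp.float32)
x = safe_arctanh(y)
print(x)  # all entries finite; the boundary values map to roughly +/-7.25
print(clipped_tanh(jnp.array([-8.51089])))
```

The trade-off is a small bias near the boundary. Values that were within `eps` of +/-1 are all mapped to the same point, so `eps` should be no larger than needed.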
I recently stumbled on this issue here. When I run the following
It prints
Any ideas? This seems like it may be related to https://github.com/tensorflow/tensorflow/issues/35435, but I'm not sure how distrax would be affected by the TensorFlow Probability package directly.
I'm currently using distrax version 0.1.3.
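Besides clipping, a different technique that is commonly used for tanh-squashed Gaussians (for example in SAC-style policies) is to keep the pre-tanh sample `x` around. The log-density is then evaluated from `x` directly, so `tanh(x)` never has to be inverted. The Jacobian term can also be written without forming `tanh(x)`, using the identity log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)). This is a swapped-in technique, not something suggested in this thread. Whether distrax 0.1.3's `Tanh` bijector computes things this way is an assumption I have not checked. All function names below are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def log1m_tanh_sq_naive(x):
    # log(1 - tanh(x)^2) computed literally; may come out as -inf or nan
    # once tanh(x) rounds to, or past, +/-1 in float32.
    return jnp.log(1.0 - jnp.tanh(x) ** 2)

def log1m_tanh_sq_stable(x):
    # Same quantity via log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)),
    # which never forms tanh(x) and stays finite for large |x|.
    return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))

def squashed_gaussian_log_prob(x, loc, scale):
    # Hypothetical helper: log-density of y = tanh(x) with x ~ Normal(loc, scale),
    # evaluated from the pre-tanh x (change of variables), so no arctanh(y) is needed.
    return norm.logpdf(x, loc, scale) - log1m_tanh_sq_stable(x)

x = jnp.array([-8.51089, 0.5], dtype=jnp.float32)
print(log1m_tanh_sq_naive(x))
print(log1m_tanh_sq_stable(x))
print(squashed_gaussian_log_prob(x, 0.0, 1.0))
```

The design choice here is to avoid the lossy round trip x -> tanh(x) -> arctanh entirely. Near the boundary, float32 cannot distinguish many different `x` values after squashing.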