google-deepmind / distrax

Apache License 2.0

Tanh producing values smaller than -1 #235

Closed StoneT2000 closed 1 year ago

StoneT2000 commented 1 year ago

I recently stumbled on this issue here. When I run the following

import jax
import jax.numpy as jnp
import distrax
log_std = jnp.array([-100])
a = jnp.array([-8.51089])
dist = distrax.Normal(a, jnp.exp(log_std))
dist = distrax.Transformed(distribution=dist, bijector=distrax.Tanh())
s = dist.sample_and_log_prob(seed=jax.random.PRNGKey(0))
print(s)

It prints

(Array([-1.0000001], dtype=float32), Array(114.67845, dtype=float32))

Any idea? This seems like it may be related to https://github.com/tensorflow/tensorflow/issues/35435, but I'm not sure how distrax is affected by the TensorFlow Probability package directly.

I'm currently using version 0.1.3 of distrax.
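For what it's worth, the large log-prob value by itself appears consistent with the change-of-variables formula for a tanh-transformed Normal: `log p(y) = log N(x; mu, sigma) - log(1 - tanh(x)^2)`, and with `sigma = exp(-100)` the density at the mean is enormous. A rough check (a sketch, assuming the sample `x` is effectively the mean since `sigma` is tiny):

```python
import math

mu, log_std = -8.51089, -100.0
# With sigma = exp(-100), the pre-tanh sample x is effectively the mean mu.
x = mu
# Normal log-density at its mean: -log(sigma) - 0.5 * log(2 * pi)
log_prob_x = -log_std - 0.5 * math.log(2 * math.pi)
# Change-of-variables term for y = tanh(x): log|dy/dx| = log(1 - tanh(x)^2)
log_det = math.log(1.0 - math.tanh(x) ** 2)
log_prob_y = log_prob_x - log_det
print(log_prob_y)  # roughly 114.7, matching the reported 114.68 up to float32 rounding
```

So the questionable part is really the sample `-1.0000001` falling outside the support of tanh, not the magnitude of the log-prob.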

franrruiz commented 1 year ago

This seems like a numerical representation issue that is independent of distrax and is also platform-dependent.

For example, I just tried the following code on a Google Colab:

import jax.numpy as jnp

a = jnp.array([-8.51089])
print(jnp.tanh(a))

It prints

Array([-1.0000001], dtype=float32)

when working on a GPU, or

Array([-1.], dtype=float32)

when working on a CPU.

To avoid numerical issues, we strongly recommend clipping your input or output values to ensure they remain within some reasonable range.
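A minimal sketch of that clipping suggestion (the `eps` margin is an assumed value, not something prescribed by distrax; tune it for your use case):

```python
import jax.numpy as jnp

a = jnp.array([-8.51089])
y = jnp.tanh(a)  # may land slightly outside [-1, 1] in float32 on some backends

# Clamp the output strictly inside (-1, 1) before any downstream use
# (e.g. atanh or log-prob computations, which blow up at the boundary).
eps = 1e-6  # assumed margin
y_clipped = jnp.clip(y, -1.0 + eps, 1.0 - eps)
print(y_clipped)
```

The same idea applies on the input side: clipping the pre-tanh values to a moderate range (say `[-10, 10]`) keeps `tanh` away from its saturated region in the first place.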