**Open** · vwxyzjn opened 1 year ago
Also ran into the `Tanh` bijector + `Transformed` causing NaNs. #7 has a workaround.
Did some digging into this because it was really bothering me, and it turns out the behaviour is somewhat expected: it isn't really distrax's fault. The conclusion I came to is pretty much what this comment in #7 describes as well, but perhaps it is worth documenting here in greater detail since this issue is still open.
If you print the sampled actions in this code snippet rather than their sum, you will notice that the value at index `[0, 0]` is 1.0000001. Then, calling `jnp.arctanh()` as part of the inverse pass of the `Tanh` bijector, you get a `nan`.
Obviously such a value is outside the range of `tanh` and shouldn't occur, but it does because of numerical precision. If you switch to 64-bit precision with `jax.config.update("jax_enable_x64", True)`, you don't get such values and the code snippet works fine.
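As a minimal sketch of the failure mode (using NumPy as a stand-in for `jnp`; the float32 rounding behaviour is the same), the inverse can both overshoot past 1.0 into `nan` territory and saturate to exactly 1.0:

```python
import numpy as np

# 1.0000001 is not representable in float32; it rounds up past 1.0,
# so arctanh (the Tanh bijector's inverse) leaves tanh's range.
almost_one = np.float32(1.0000001)   # stored as 1.00000012
with np.errstate(invalid="ignore"):
    bad = np.arctanh(almost_one)
print(np.isnan(bad))                 # True

# Even without overshooting, float32 rounds tanh(10.0) to exactly 1.0,
# so the inverse blows up; in float64 the round trip stays finite.
with np.errstate(divide="ignore"):
    print(np.arctanh(np.tanh(np.float32(10.0))))   # inf
print(np.isfinite(np.arctanh(np.tanh(np.float64(10.0)))))  # True
```

This is why enabling x64 makes the snippet work: the sampled `tanh` values then stay strictly inside (-1, 1).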
As a side note, the reason `custom_log_prob()` returns a value here is that it doesn't actually take the `arctanh()` of the sampled action. If you look closely at the snippet, the function discards the `gaussian_action` argument it takes and reinitialises it by sampling from a normal distribution, which is wrong (it only works if the same RNG key was used for the actions whose log prob is being computed). If you cut that line out, it too returns `nan`, just like tfp and distrax.
Therefore, this isn't something that can be fixed on distrax's end. The reason #7's workaround works is that it computes the log prob of the sampled actions using the pre-tanh value (which is readily available since the operation includes a forward sampling pass), so numerical precision never becomes a problem. Calling `log_prob()` on pre-sampled actions, however (where the pre-tanh value isn't readily available), requires a call to `arctanh()` and runs into the problem above unless 64-bit precision is used.
To conclude, the ways around this that I can think of are to either:

- Avoid `log_prob()` and only use `sample_and_log_prob()` (which, depending on your use case, might actually be possible, e.g. in RL for SAC).
- Sample from the underlying `MultivariateNormalDiag`, store the pre-tanh values, then compute actions as `actions = jnp.tanh(pre_tanh_actions)` and the log probs as:
  ```python
  log_prob = normal_dist.log_prob(pre_tanh_actions) - jnp.sum(2 * (jnp.log(2) - pre_tanh_actions - jax.nn.softplus(-2 * pre_tanh_actions)), axis=-1)
  ```
- Bound `log_std_max` in your network, or use any other tools that would bound the mean and std to values that lead to more reasonable numbers.
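The softplus formula in the second option avoids `arctanh()` entirely: `2 * (log 2 - x - softplus(-2x))` is algebraically equal to `log(1 - tanh(x)^2)`, the log-det-Jacobian of `tanh`. A pure-NumPy sketch checking the equivalence (`normal_logpdf`, `stable_tanh_log_prob`, and `naive_tanh_log_prob` are hypothetical helper names, not distrax API):

```python
import numpy as np

def normal_logpdf(x, mean, std):
    # diagonal Gaussian log density, elementwise
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

def stable_tanh_log_prob(pre_tanh, mean, std):
    # uses the pre-tanh value; log(1 - tanh(x)^2) via the softplus identity,
    # no arctanh needed (np.logaddexp(0, z) == softplus(z))
    correction = 2 * (np.log(2.0) - pre_tanh - np.logaddexp(0.0, -2.0 * pre_tanh))
    return np.sum(normal_logpdf(pre_tanh, mean, std) - correction, axis=-1)

def naive_tanh_log_prob(action, mean, std):
    # requires arctanh, which produces nan/inf if |action| reaches 1.0
    x = np.arctanh(action)
    return np.sum(normal_logpdf(x, mean, std) - np.log(1.0 - action ** 2), axis=-1)

rng = np.random.default_rng(0)
mean, std = np.zeros(3), np.ones(3)
pre_tanh = rng.normal(mean, std)
action = np.tanh(pre_tanh)
# the two paths agree in float64, where the round trip is safe
print(stable_tanh_log_prob(pre_tanh, mean, std))
print(naive_tanh_log_prob(action, mean, std))
```

The stable path never materialises `arctanh(action)`, which is exactly why #7's workaround sidesteps the precision problem.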
Hello, thanks for this awesome repo! We have had a slight issue with distrax producing `nan`s at https://github.com/vwxyzjn/cleanrl/pull/300. See the following reproduction script: