google-deepmind / distrax


`nan` in MultivariateNormalDiag log prob #216

Open vwxyzjn opened 1 year ago

vwxyzjn commented 1 year ago

Hello, thanks for this awesome repo! We ran into an issue where distrax produces `nan`s, in https://github.com/vwxyzjn/cleanrl/pull/300. See the following reproduction script:

```python
from typing import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax

# import pybullet_envs  # noqa
import tensorflow_probability
from flax.training.train_state import TrainState

tfp = tensorflow_probability.substrates.jax
tfd = tfp.distributions
jax.config.update("jax_platform_name", "cpu")
import distrax

class Actor(nn.Module):
    action_dim: Sequence[int]
    n_units: int = 256
    log_std_min: float = -20
    log_std_max: float = 2

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        mean = nn.Dense(self.action_dim)(x)
        log_std = nn.Dense(self.action_dim)(x)
        log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std

# @jax.jit
def custom_log_prob(
    mean: jnp.ndarray,
    log_std: jnp.ndarray,
    subkey: jax.Array,
    gaussian_action: jnp.ndarray,
):
    std = jnp.exp(log_std)
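    # NOTE: this next line overwrites the gaussian_action argument by
    # re-sampling it (see rainx0r's comment below); remove it to use the
    # passed-in actions.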
    gaussian_action = mean + std * jax.random.normal(subkey, shape=mean.shape)
    log_prob = -0.5 * ((gaussian_action - mean) / std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - log_std
    log_prob = log_prob.sum(axis=1)
    # https://github.com/vwxyzjn/cleanrl/pull/300#issuecomment-1326285592
    log_prob -= jnp.sum(2.0 * (np.log(2.0) - gaussian_action - jax.nn.softplus(-2.0 * gaussian_action)), 1)
    return log_prob

if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    key, actor_key = jax.random.split(key, 2)
    # with open("test.npy", "rb") as f:
    #     obs = np.load(f)
    obs = jnp.array([[ -0.06284985,  -0.0164921 ,  -0.10846169,   0.28114545,
         -0.28463456,   0.4503281 ,   0.27488193,  -0.0666963 ,
          0.6118138 ,   0.34202537,  -1.262452  ,   0.7542422 ,
         13.809639  ,  -0.6205632 ,  -4.0013294 ,   5.3532414 ,
         11.587792  ],
       [ -0.15303956,   0.9534635 ,  -0.3092537 ,  -0.2033926 ,
          0.03336933,   0.6362027 ,   0.02348915,  -0.32627296,
         -0.29046476,   0.46484601,  -0.42002085,  -3.1616204 ,
          2.247283  ,  14.114895  ,   2.6248324 ,  -1.9809983 ,
        -12.693646  ],
       [ -0.07995494,   0.09804074,  -0.20460981,  -0.13476144,
          0.1701505 ,   0.05989099,  -0.06446445,  -0.22749065,
          0.39946172,   0.42318228,   2.5876977 ,   3.8510017 ,
         -8.23167   ,  -7.292657  ,   7.64345   ,  -9.558817  ,
         -1.9690503 ],
    ])
    # obs = obs[0:5]
    actor = Actor(action_dim=6)
    actor_state = TrainState.create(
        apply_fn=actor.apply,
        params=actor.init(actor_key, obs),
        tx=optax.adam(learning_rate=3e-4),
    )

    key, subkey = jax.random.split(key, 2)
    mean, log_std = actor.apply(actor_state.params, obs)
    action_std = jnp.exp(log_std)
    tfd_dist = tfd.TransformedDistribution(
        tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), bijector=tfp.bijectors.Tanh()
    )
    distrax_dist = distrax.Transformed(
        distrax.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), bijector=distrax.Block(distrax.Tanh(), 1)
    )

    # action generation
    gaussian_action = mean + action_std * jax.random.normal(subkey, shape=mean.shape)
    action_custom = jnp.tanh(gaussian_action)
    reverse_action_custom = jnp.arctanh(action_custom)
    action_tfp = tfd_dist.sample(seed=subkey)
    action_distrax = distrax_dist.sample(seed=subkey)

    print("action_custom.sum()", action_custom.sum())
    print("action_tfp.sum()", action_tfp.sum())
    print("action_distrax.sum()", action_distrax.sum())
    print("gaussian_action.sum()", gaussian_action.sum())
    print("reverse_action_custom.sum()", reverse_action_custom.sum())

    # log_prob
    for action, name in zip(
        [action_custom, action_tfp, action_distrax], ["action_custom", "action_tfp", "action_distrax"]
    ):
        log_prob_custom = custom_log_prob(mean, log_std, subkey, jnp.arctanh(action))
        log_prob_tfp = tfd_dist.log_prob(action)
        log_prob_distrax = distrax_dist.log_prob(action)
        print(name)
        print("┣━━ log_prob_custom.sum()", log_prob_custom.sum())
        print("┣━━ log_prob_tfp.sum()", log_prob_tfp.sum())
        print("┣━━ log_prob_distrax.sum()", log_prob_distrax.sum())
action_custom.sum() 2.8352258
action_tfp.sum() 5.978534
action_distrax.sum() 2.8352258
gaussian_action.sum() 34.332348
reverse_action_custom.sum() inf
action_custom
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() nan
┣━━ log_prob_distrax.sum() nan
action_tfp
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() 60.565056
┣━━ log_prob_distrax.sum() nan
action_distrax
┣━━ log_prob_custom.sum() 58.477264
┣━━ log_prob_tfp.sum() nan
┣━━ log_prob_distrax.sum() nan
```
kevinzakka commented 1 year ago

I also ran into the Tanh bijector + Transformed producing NaNs. #7 has a workaround.
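
For readers landing here, a minimal sketch of that kind of workaround (the general pattern, not necessarily the exact code in #7), reusing the imports and the mean / log_std / subkey from the reproduction script above: compute the log prob from the pre-tanh sample so the Tanh inverse is never evaluated.

```python
# Sketch of the workaround pattern: keep the pre-tanh sample around and
# compute the log prob from it, so jnp.arctanh is never called on a
# rounded action.
base = distrax.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std))
gaussian_action, base_log_prob = base.sample_and_log_prob(seed=subkey)
action = jnp.tanh(gaussian_action)
# Tanh log-det-Jacobian in its numerically stable softplus form:
# log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x)),
# the same correction term used in custom_log_prob above.
log_prob = base_log_prob - jnp.sum(
    2.0 * (jnp.log(2.0) - gaussian_action - jax.nn.softplus(-2.0 * gaussian_action)),
    axis=-1,
)
```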

rainx0r commented 1 year ago

Did some digging into this because it was really bothering me, and it turns out the behaviour is somewhat expected; it's not really distrax's fault. The conclusion I came to is pretty much what this comment in #7 describes as well, but perhaps it's worth documenting here in greater detail since this issue is still open.

If you print the sampled actions in this code snippet rather than their sum, you will notice that the value at index [0, 0] is 1.0000001. Calling jnp.arctanh() on it, as part of inverting the Tanh bijector, then yields nan.

Such a value is outside the range of tanh and shouldn't occur, but it does because of limited numerical precision. If you switch to 64-bit precision with jax.config.update("jax_enable_x64", True), such values no longer appear and the code snippet works fine.
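
The failure mode is easy to reproduce in isolation (1.0000001 is the value observed above):

```python
import jax.numpy as jnp

# arctanh is only defined on (-1, 1); the boundary maps to +/-inf and
# anything outside it maps to nan.
print(jnp.arctanh(jnp.float32(1.0000001)))  # nan, like the sampled action above
print(jnp.arctanh(jnp.float32(1.0)))        # inf, cf. reverse_action_custom.sum()
```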

As a side note, the reason custom_log_prob() returns a finite value here is that it never actually uses the arctanh() of the sampled action. If you look closely at the snippet, the function discards the gaussian_action argument it receives and re-creates it by sampling from the normal distribution, which is wrong (it only works because the same rng key was used for the actions whose log prob is being computed). If you cut that line out, it too returns nan, just like tfp and distrax.
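
For concreteness, here is a hypothetical corrected version, which is just the script's function with the re-sampling line removed so the argument is actually used; calling it with jnp.arctanh(action) for the float32 sample above now returns nan, matching tfp and distrax:

```python
# Hypothetical corrected version: actually use the gaussian_action argument.
# With a float32 action that rounds outside (-1, 1), jnp.arctanh(action)
# is nan, so this returns nan as well.
def custom_log_prob_fixed(
    mean: jnp.ndarray,
    log_std: jnp.ndarray,
    gaussian_action: jnp.ndarray,
) -> jnp.ndarray:
    std = jnp.exp(log_std)
    log_prob = -0.5 * ((gaussian_action - mean) / std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - log_std
    log_prob = log_prob.sum(axis=1)
    log_prob -= jnp.sum(2.0 * (jnp.log(2.0) - gaussian_action - jax.nn.softplus(-2.0 * gaussian_action)), 1)
    return log_prob
```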

Therefore, this isn't something that can be fixed on distrax's end. The reason #7's workaround works is that it computes the log prob of the sampled actions using the pre-tanh value, which is readily available since the operation includes a forward sampling pass, so numerical precision never becomes a problem. Calling log_prob() on pre-sampled actions, however, where the pre-tanh value isn't readily available, requires a call to arctanh() and hits the problem above unless 64-bit precision is used.

To conclude, the ways around this I can think of are to either:

- compute the log prob at sampling time from the pre-tanh value, as in #7's workaround (sketched below), or
- run under 64-bit precision with jax.config.update("jax_enable_x64", True), so that sampled actions never land outside tanh's range.
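
The first option is available directly on the transformed distribution (assuming the distrax_dist and subkey from the reproduction script):

```python
# sample_and_log_prob draws the sample and evaluates its log prob in one
# pass, using the pre-tanh value internally, so no arctanh round trip occurs.
action, log_prob = distrax_dist.sample_and_log_prob(seed=subkey)
```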