vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

SAC jax #300

Open araffin opened 1 year ago

araffin commented 1 year ago

Description

Missing: benchmark and doc

Adapted from https://github.com/araffin/sbx

Report (3 seeds on 3 MuJoCo envs): https://wandb.ai/openrlbenchmark/cleanrl/reports/SAC-jax---VmlldzoyODM4MjU0

Types of changes

Checklist:

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

araffin commented 1 year ago

@vwxyzjn the tests fail because of ModuleNotFoundError: No module named 'pygame', not sure why it worked before...

vwxyzjn commented 1 year ago

The ModuleNotFoundError: No module named 'pygame' looks really weird... so I investigated it a bit further. Instead of running poetry lock, I ran poetry add tensorflow-probability and poetry update flax, and that seems to make things work.

It turns out the culprit is the following change

-classic_control = ["pygame (==2.1.0)"]
+classic-control = ["pygame (==2.1.0)"]

We install pygame via pip install gym[classic_control] under the hood with poetry, but for some reason the key of the extra was changed 😓

araffin commented 1 year ago

@vwxyzjn I think I'm done with the implementation; I added support for a constant entropy coefficient and for deterministic evaluation. I would be happy to receive help with the documentation ;)

araffin commented 1 year ago

Is there any corresponding wandb report ?

I only did some runs after the initial ones, which were without the deterministic eval. You can see them in cleanrl, for instance by searching for "HalfCheetah-v2" or "Ant-v2": https://wandb.ai/openrlbenchmark/cleanrl?workspace=user-araffin. They are under the name "sac_continuous_actions_jax" ("sac_jax" was without deterministic eval).

So, as expected, the evaluation performance is slightly better than the training performance with the stochastic policy. I would also need help to update the doc.

[W&B charts, 10/26/2022]

vwxyzjn commented 1 year ago

Hey @araffin, thanks for the implementation! The results look great :) I am a little overwhelmed with a bunch of stuff these two weeks, so my apologies in advance for the delay in my review.

araffin commented 1 year ago

so my apologies in advance for the delay in my review.

No worries, take the time you need.

vwxyzjn commented 1 year ago

Btw, here is a plotting snippet to make our lives easier. By specifying the following

    env_ids = [
        "HalfCheetah-v2",
        "Hopper-v2",
        "Walker2d-v2",
    ]
    exp_names = [
        "sac_jax",
        "sac_continuous_action",
        "sac_continuous_action_deter_eval",
    ]
    runsetss = []
    for exp_name in exp_names:
        runsetss += [
            [
                Runset(
                    name=f"CleanRL's {exp_name}",
                    filters={"$and": [{"config.env_id.value": env_id}, {"config.exp_name.value": exp_name}]},
                    entity="openrlbenchmark",
                    project="cleanrl",
                    groupby="exp_name",
                )
                for env_id in env_ids
            ]
        ]

it can generate this wandb report

[image: generated wandb report]

and the following image

[image: generated figure]
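As a side note (not part of the snippet above), the filters use wandb's MongoDB-style query syntax, so each Runset can be thought of as a thin wrapper around a query against the public wandb API. A hedged sketch of what one such query resolves to, reusing the entity/project/filters from the snippet:

# Illustrative sketch only: resolving one Runset-like query with the public wandb API.
import wandb

api = wandb.Api()
runs = api.runs(
    path="openrlbenchmark/cleanrl",
    filters={
        "$and": [
            {"config.env_id.value": "HalfCheetah-v2"},
            {"config.exp_name.value": "sac_jax"},
        ]
    },
)
for run in runs:
    # Each run carries its config, so we can confirm which experiment it belongs to
    print(run.name, run.config.get("exp_name"), run.config.get("env_id"))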
vwxyzjn commented 1 year ago

Perhaps it's because I implemented my own normal distribution in https://github.com/vwxyzjn/cleanrl/pull/217 that I am trying to do the same for SAC...

However, if I replace

def actor_loss(params):
    dist = TanhTransformedDistribution(
        tfd.MultivariateNormalDiag(loc=action_mean, scale_diag=jnp.exp(action_logstd)),
    )
    actor_actions = dist.sample(seed=subkey)
    log_prob = dist.log_prob(actor_actions).reshape(-1, 1)

with the log probability taken from https://github.com/openai/baselines/blob/9b68103b737ac46bc201dfb3121cfa5df2127e53/baselines/common/distributions.py#L238-L241

def actor_loss(params):
    action_mean, action_logstd = actor.apply(params, observations[0:1])
    action_std = jnp.exp(action_logstd)
    actor_actions = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
    log_prob = -0.5 * ((actor_actions - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
    log_prob = log_prob.sum(axis=1, keepdims=True)
    actor_actions = jnp.tanh(actor_actions)

things kind of fall apart catastrophically... I felt that implementing our own might bring greater transparency, but maybe it's not necessary...

vwxyzjn commented 1 year ago

Aha! I got it, it's supposed to be the following

[image]
action_mean, action_logstd = actor.apply(params, observations)
action_std = jnp.exp(action_logstd)
actor_actions = action_mean + action_std * jax.random.normal(subkey, shape=action_mean.shape)
log_prob = -0.5 * ((actor_actions - action_mean) / action_std) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi) - action_logstd
actor_actions = jnp.tanh(actor_actions)
log_prob -= jnp.log((1 - jnp.power(actor_actions, 2)) + 1e-6)
log_prob = log_prob.sum(axis=1, keepdims=True)

Interestingly, the paper seems to say our implementation should have been the following (with the summation)

log_prob -= jnp.log((1 - jnp.power(actor_actions, 2)) + 1e-6).sum(axis=-1).reshape(-1, 1)

but empirically, it doesn't perform as well... @dosssman any thoughts?

araffin commented 1 year ago

Interestingly, the paper seems to say our implementation should have been the following (with the summation)

Not sure I follow the difference...

You can take a look at how we do it in SB3; I think it matches what is described in the paper: https://github.com/DLR-RM/stable-baselines3/blob/c4f54fcf047d7bf425fb6b88a3c8ed23fe375f9b/stable_baselines3/common/distributions.py#L222-L226
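For reference, the linked SB3 lines sum the Gaussian log-probabilities over the action dimensions and subtract the summed tanh correction, so the result is one scalar per batch element. A rough JAX port of that structure (a sketch only, not code from this PR; the function name and SB3's epsilon=1e-6 are assumptions):

import jax
import jax.numpy as jnp

def squashed_gaussian_log_prob(mean, log_std, key, epsilon=1e-6):
    # Sample from the underlying diagonal Gaussian (reparameterized).
    std = jnp.exp(log_std)
    gaussian_actions = mean + std * jax.random.normal(key, shape=mean.shape)
    # Diagonal-Gaussian log-probability, summed over action dimensions -> shape (batch,)
    log_prob = (
        -0.5 * ((gaussian_actions - mean) / std) ** 2
        - 0.5 * jnp.log(2.0 * jnp.pi)
        - log_std
    ).sum(axis=1)
    # Squash, then apply the change-of-variables correction, also summed over dimensions
    actions = jnp.tanh(gaussian_actions)
    log_prob -= jnp.log(1.0 - jnp.square(actions) + epsilon).sum(axis=1)
    return actions, log_prob.reshape(-1, 1)

key = jax.random.PRNGKey(0)
mean = jnp.zeros((4, 6))           # batch of 4, 6-dim actions
log_std = jnp.full((4, 6), -1.0)
actions, logp = squashed_gaussian_log_prob(mean, log_std, key)
print(actions.shape, logp.shape)   # (4, 6) (4, 1)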

vwxyzjn commented 1 year ago

I tried to implement the probability distribution ourselves in https://github.com/vwxyzjn/cleanrl/pull/300/commits/0cf0e9e8afe56c54485f80c74dfd992d1a5a79fc, but hit a performance regression.

[image]

Looking into the issue deeper, I couldn't quite understand how TanhTransformedDistribution works. Could someone take a look at https://gist.github.com/vwxyzjn/331f896b79d3f829fdfa575be666d2d8, which generates

manually sample actions, manually calculate log prob
  action=2.561650514602661, logprob=55.152984619140625
manually sample actions, calculate log prob from TanhTransformedDistribution
  action=2.561650514602661, logprob=nan
sample actions from `TanhTransformedDistribution`, calculate log prob from TanhTransformedDistribution
  action=2.7475833892822266, logprob=66.45195770263672
sample actions from `TanhTransformedDistribution`, manually calculate log prob
  action=2.7475833892822266, logprob=-inf

I am quite puzzled. TanhTransformedDistribution seems like quite a black box to me. Because tensorflow_probability is written in tensorflow, there is no meaningful code trace in the IDE to understand what's happening inside... And tfp's docs seem to have some issues (e.g., the "view source code on Github" button in https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/MultivariateNormalDiag is broken). Maybe we shouldn't use anything from tfp?

araffin commented 1 year ago

@vwxyzjn run the code with JAX_ENABLE_X64=True and it will solve your issue ;) (results are still slightly different, but that's probably expected; try with different random seeds). JAX_DISABLE_JIT=1 already partially solves the issue.

I guess the answer to your question is called numerical precision ;).

EDIT: the relevant tfp code (the Tanh bijector) is here: https://github.com/tensorflow/probability/blob/bcdf53024ef9f35d81be063093ccfb3a762dab3f/tensorflow_probability/python/bijectors/tanh.py#L70-L81

  # We implicitly rely on _forward_log_det_jacobian rather than explicitly
  # implement _inverse_log_det_jacobian since directly using
  # `-tf.math.log1p(-tf.square(y))` has lower numerical precision.

  def _forward_log_det_jacobian(self, x):
    #  This formula is mathematically equivalent to
    #  `tf.log1p(-tf.square(tf.tanh(x)))`, however this code is more numerically
    #  stable.
    #  Derivation:
    #    log(1 - tanh(x)^2)
    #    = log(sech(x)^2)
    #    = 2 * log(sech(x))
    #    = 2 * log(2e^-x / (e^-2x + 1))
    #    = 2 * (log(2) - x - log(e^-2x + 1))
    #    = 2 * (log(2) - x - softplus(-2x))
    return 2. * (np.log(2.) - x - tf.math.softplus(-2. * x))
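For illustration (a minimal sketch, not code from this PR, using only standard JAX APIs), the difference between the naive correction and the softplus form above shows up as soon as tanh saturates, even with float64 enabled:

import jax
import jax.numpy as jnp

# Equivalent to running with JAX_ENABLE_X64=True
jax.config.update("jax_enable_x64", True)

def naive_log_det_jacobian(x):
    # Direct computation of log(1 - tanh(x)^2); breaks once tanh(x) rounds to 1.0
    return jnp.log1p(-jnp.square(jnp.tanh(x)))

def stable_log_det_jacobian(x):
    # Same quantity rewritten as 2 * (log(2) - x - softplus(-2x)), as in the tfp Tanh bijector
    return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))

x = jnp.array([0.0, 5.0, 20.0])
print(naive_log_det_jacobian(x))   # last entry is -inf: tanh(20.0) rounds to exactly 1.0
print(stable_log_det_jacobian(x))  # all entries stay finite

This is the same saturation that produces the -inf / nan values in the gist output above, i.e., numerical precision rather than a bug in TanhTransformedDistribution.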
araffin commented 1 year ago

run the code with JAX_ENABLE_X64=True and it will solve your issue ;) (results are still slightly different, but that's probably expected, try with different random seeds)

@vwxyzjn as a follow-up, if you remove the + 1e-6 in your code, you get the same results. Btw, why did you use 1e-6 and not a smaller value?

EDIT: I don't know why pre-commit fails; it works locally.

Howuhh commented 1 year ago

@araffin 1e-6 is used in most popular SAC PyTorch implementations; I also use it in my research for some reason (and in CORL). I think it's more a matter of reproducibility.

ffelten commented 1 year ago

Hi, is there any update on this, or anything blocking it?

araffin commented 1 year ago

@vwxyzjn I need your help again to update the lockfile; I tried to do it locally and poetry destroyed my conda env...