araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)
MIT License

SB3 and SBX versions of SAC have radically different behaviours #55

Open jamesheald opened 1 day ago

jamesheald commented 1 day ago

I am following a tutorial that trains the myosuite myoHandReorient8-v0 environment using the stable baselines 3 version of SAC. The main block of training code (included here because it shows the SAC parameters) is:

# Imports assumed by this snippet (linear_schedule and SaveSuccesses are
# helpers defined elsewhere in the tutorial):
import gymnasium as gym  # the tutorial may use `import gym` instead

from stable_baselines3 import SAC  # or: from sbx import SAC
from stable_baselines3.common.logger import configure
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize


def train(env_name, policy_name, timesteps, seed):
    """
    Trains a policy using the sb3 implementation of SAC.

    env_name: str; name of the gym env.
    policy_name: str; a unique identifier for this policy.
    timesteps: int; how long you want to train your policy for.
    seed: str (not int); relevant if you want to train multiple policies with the same params.
    """
    # Wrap the environment: episode monitoring, vectorization, observation normalization.
    env = gym.make(env_name)
    env = Monitor(env)
    env = DummyVecEnv([lambda: env])
    env = VecNormalize(env, norm_obs=True, norm_reward=False, clip_obs=10.)

    # Actor and critic networks both use two hidden layers of 400 and 300 units.
    net_shape = [400, 300]
    policy_kwargs = dict(net_arch=dict(pi=net_shape, qf=net_shape))

    model = SAC('MlpPolicy', env, learning_rate=linear_schedule(.001), buffer_size=int(3e5),
                learning_starts=1000, batch_size=256, tau=.02, gamma=.98, train_freq=(1, "episode"),
                gradient_steps=-1, policy_kwargs=policy_kwargs, verbose=1)

    succ_callback = SaveSuccesses(check_freq=1, env_name=env_name + '_' + seed,
                                  log_dir=f'{policy_name}_successes_{env_name}_{seed}')

    model.set_logger(configure(f'{policy_name}_results_{env_name}_{seed}'))
    model.learn(total_timesteps=int(timesteps), callback=succ_callback, log_interval=4)
    model.save(f"{policy_name}_model_{env_name}_{seed}")
    env.save(f'{policy_name}_env_{env_name}_{seed}')

When I call this train function and use the stable baselines 3 version of SAC (from stable_baselines3 import SAC), the model trains well. However, if I instead use the sbx version of SAC (from sbx import SAC), the actor loss, critic loss, and entropy coefficient diverge:

[figure: training curves showing the actor loss, critic loss, and entropy coefficient diverging with SBX SAC]

The mujoco simulation also often becomes unstable in the SBX case:

WARNING:absl:Nan, Inf or huge value in QACC at DOF 26. The simulation is unstable. Time = 0.2380.
Simulation couldn't be stepped as intended. Issuing a reset

Naively, I would have thought that the SB3 and SBX versions of SAC would perform approximately the same for the same training parameters. Can you help me understand why this is not the case, and why parameters that work well for SB3 SAC are catastrophic for SBX SAC?

I am using stable_baselines3 2.3.2 and sbx 0.13.0.

araffin commented 1 day ago

Hello, I guess your issue is that SBX doesn't support learning rate schedules. This is a current limitation because of Jax. You could use a lower but constant learning rate instead.
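For example, keeping the other settings from the train() function above, something like this (the constant value is only illustrative):

model = SAC('MlpPolicy', env, learning_rate=3e-4, buffer_size=int(3e5),
            learning_starts=1000, batch_size=256, tau=.02, gamma=.98,
            train_freq=(1, "episode"), gradient_steps=-1,
            policy_kwargs=policy_kwargs, verbose=1)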

jamesheald commented 1 day ago

Hi @araffin,

Thanks for pointing out that difference. I hadn't seen it.

I'm not sure I understand your point about jax not supporting lr schedule though. The adam optimizer in optax, which sbx uses, does support learning rate schedules (see here). Indeed, optax has a whole zoo of schedules available. Am I missing something?
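For concreteness, a minimal example of the kind of schedule I mean (the values are made up); the schedule is passed to the optimizer in place of a float:

import optax

# An optax schedule is a function of the optimizer's step count (i.e. gradient steps).
schedule = optax.linear_schedule(
    init_value=1e-3,           # initial learning rate
    end_value=0.0,             # final learning rate
    transition_steps=100_000,  # number of gradient steps over which to decay
)
optimizer = optax.adam(learning_rate=schedule)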

araffin commented 1 day ago

Am I missing something?

The optax schedules are defined in terms of gradient steps, whereas the SB3 schedules use the total timesteps (which is not known when creating the model). Last time I tried, it was not possible to do something like https://github.com/DLR-RM/stable-baselines3/blob/512eea923afad6f6da4bb53d72b6ea4c6d856e59/stable_baselines3/common/base_class.py#L300C13-L300C33

https://github.com/DLR-RM/stable-baselines3/blob/512eea923afad6f6da4bb53d72b6ea4c6d856e59/stable_baselines3/common/utils.py#L68-L77
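To make the difference concrete, a rough sketch (not the exact code behind the links above):

# SB3 convention: a schedule is a callable of progress_remaining, which goes
# from 1 (start of training) to 0 (end), so it implicitly depends on the
# total_timesteps passed to learn().
def sb3_linear_schedule(initial_value):
    def func(progress_remaining):
        return progress_remaining * initial_value
    return func

# optax convention: a schedule is a callable of the gradient step count, so the
# total number of gradient steps has to be known (or assumed) when it is built.
def optax_style_linear_schedule(initial_value, total_gradient_steps):
    def func(step):
        return max(1.0 - step / total_gradient_steps, 0.0) * initial_value
    return func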

araffin commented 1 day ago

Looking at https://github.com/google-deepmind/optax/issues/4, re-creating the optimizer using the previous state would be an option (although it sounds a bit overkill).
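A rough sketch of that idea with toy placeholders standing in for the real model parameters and gradients:

import jax.numpy as jnp
import optax

# Toy placeholders for the actual parameter and gradient pytrees.
params = {"w": jnp.zeros(3)}
grads = {"w": jnp.ones(3)}

# optax.adam keeps its moment estimates in the optimizer state, while the
# learning rate lives in the transformation itself, so a freshly created
# transform can keep updating from the previous state.
tx = optax.adam(1e-3)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)

# Later: switch to a lower learning rate without losing the accumulated moments.
tx = optax.adam(3e-4)
updates, opt_state = tx.update(grads, opt_state, params)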

jamesheald commented 21 hours ago

Things are more stable when I use an optax linear learning rate schedule, but the performance is still bad.

I noticed that SB3 SAC optimizes the log of the entropy coeff instead of the entropy coeff (link), as discussed here. In contrast, SBX SAC optimizes the entropy coeff (link).
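Roughly, the difference comes down to which quantity the temperature loss is written in terms of (a sketch, not the exact SB3/SBX code; log_prob and target_entropy stand for the usual SAC quantities):

import jax.numpy as jnp

# Temperature loss written in terms of the entropy coefficient itself:
def ent_coef_loss_direct(ent_coef, log_prob, target_entropy):
    return -jnp.mean(ent_coef * (log_prob + target_entropy))

# Temperature loss written in terms of the log entropy coefficient, as in SB3 SAC;
# the coefficient used elsewhere is then exp(log_ent_coef), which stays positive,
# and the gradient w.r.t. the learned parameter is scaled differently.
def ent_coef_loss_log(log_ent_coef, log_prob, target_entropy):
    return -jnp.mean(log_ent_coef * (log_prob + target_entropy))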

I've modified SBX SAC so that it optimizes the log of the entropy coeff, as in SB3 SAC, and now the performance is good (I haven't done extensive testing to see if it is as good as SB3, but it is certainly respectable now). I've made a PR for this change, which should improve SBX SAC and make it equivalent (in this respect) to SB3 SAC.

I'm happy to look into incorporating optax schedules too, in a way that's consistent with how schedules are used in SB3. It seems like it shouldn't be difficult in principle. On this note, is there a reason why you prefer to update the learning rate based on environment time steps rather than gradient steps --- the latter seems more natural to me. For example, say you wait until the end of an episode to perform multiple gradient steps (e.g. as many gradient steps as time steps in the episode); then the learning rate changes abruptly at the end of each episode and is constant for all the gradient steps for that episode. If it was instead updated based on gradient steps, the learning rate would change gradually. For a linear schedule, the former gives a piecewise-constant learning rate while the latter is actually linear.
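To illustrate with made-up numbers (a linear schedule over 1000 gradient steps, episodes of 100 steps, one gradient step per environment step):

initial_lr = 1e-3
total_steps = 1000
episode_length = 100

# Time-step-based update, refreshed once per episode (as with
# train_freq=(1, "episode")): piecewise constant within each episode.
def lr_refreshed_per_episode(step):
    completed = (step // episode_length) * episode_length
    progress_remaining = 1.0 - completed / total_steps
    return progress_remaining * initial_lr

# Gradient-step-based update (as an optax schedule would do): truly linear.
def lr_per_gradient_step(step):
    return (1.0 - step / total_steps) * initial_lr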

araffin commented 2 hours ago

why you prefer to update the learning rate based on environment time steps rather than gradient steps --- the latter seems more natural to me.

When not talking about RL, I agree it is more natural.

In RL, you don't always know how many gradient steps you are going to do. For instance, with PPO you can have an early exit if the KL divergence between the old and new policy is too large. Reasoning in terms of how much experience the agent has collected to decide when to change the lr is also more natural to me.

then the learning rate changes abruptly at the end of each episode and is constant for all the gradient steps for that episode

That's true for episodic RL (which is not what most people do nowadays), although I would disagree that the lr changes abruptly (an episode is usually much shorter than the timescale over which the lr schedule changes).

I'm happy to look into incorporating optax schedules too, in a way that's consistent with how schedules are used in SB3.

that would be a nice addition =)