DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] Making the squash transform optional for SAC/TD3? #195

Closed: ManifoldFR closed this issue 3 years ago

ManifoldFR commented 4 years ago

Describe the bug

Some environments, such as PyBullet's implementation of the DeepMimic environment, HumanoidDeepMimicBulletEnv-v1 (which I've successfully trained on using this library's implementation of PPO), don't seem to respond well to output squashing with the tanh function. I got poor behavior with SAC, and I think this is the cause: large regions of the action space (mostly the edges of the [-1, 1] box) contain actions that almost always lead to failure states, because that environment terminates episodes early, yet the squashed Gaussian can put a lot of mass in these regions during learning. For the DeepMimic environment, these actions correspond to target poses that send the character flying away.
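The claim about mass near the edges can be quantified. If the pre-squash sample is u ~ N(mu, sigma^2) and the action is a = tanh(u), then |a| > t exactly when |u| > atanh(t), because tanh is monotonic. The plain-Python sketch below computes that probability in closed form for a one-dimensional action; `edge_mass` and the 0.9 threshold for "near the edge" are illustrative choices, not part of stable-baselines3.

```python
import math


def edge_mass(mu: float, sigma: float, threshold: float = 0.9) -> float:
    """Probability that a = tanh(u), u ~ N(mu, sigma^2), lands in |a| > threshold.

    tanh is monotonic, so |tanh(u)| > threshold  <=>  |u| > atanh(threshold).
    """
    c = math.atanh(threshold)
    scale = sigma * math.sqrt(2.0)
    upper = 0.5 * math.erfc((c - mu) / scale)  # P(u > c)
    lower = 0.5 * math.erfc((c + mu) / scale)  # P(u < -c)
    return upper + lower


if __name__ == "__main__":
    for sigma in (0.5, 1.0, 3.0):
        print(f"sigma={sigma}: P(|a| > 0.9) = {edge_mass(0.0, sigma):.3f}")
```

With a zero mean, this gives roughly 0.003, 0.14 and 0.62 for sigma = 0.5, 1 and 3: once the pre-squash distribution is wide, most squashed samples sit in the outer bands of the [-1, 1] box, which is where the post suspects the failure-prone actions are.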

I've tried addressing this without removing the tanh yet:

araffin commented 4 years ago

which I've successfully trained on using this library's implementation of PPO), don't seem to respond well to output squashing using the tanh function.

I've tried addressing this without removing the tanh yet:

Why didn't you define a custom Actor and check that this was actually the problem? (Replace the SquashedGaussianDistribution with a diagonal Gaussian distribution; you may need to make some edits if you want the log_std to depend on the state, but at first I would try without.)
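The practical difference between the two distributions shows up in the log-likelihood that SAC uses for its entropy term. For a squashed sample a = tanh(u), the change-of-variables formula subtracts the sum of log(1 - tanh(u_i)^2) from the Gaussian log-density of u; a plain diagonal Gaussian has no such term and no built-in bounds. The sketch below is plain Python with hypothetical helper names, and the small `eps` guard is an assumed numerical safeguard rather than a quote of the library code.

```python
import math
from typing import Sequence


def diag_gaussian_log_prob(u: Sequence[float], mu: Sequence[float],
                           log_std: Sequence[float]) -> float:
    """Log-density of u under a diagonal Gaussian N(mu, exp(log_std)^2)."""
    total = 0.0
    for u_i, mu_i, ls_i in zip(u, mu, log_std):
        std = math.exp(ls_i)
        total += -0.5 * ((u_i - mu_i) / std) ** 2 - ls_i - 0.5 * math.log(2.0 * math.pi)
    return total


def squashed_log_prob(u: Sequence[float], mu: Sequence[float],
                      log_std: Sequence[float], eps: float = 1e-6) -> float:
    """Log-density of a = tanh(u): Gaussian term minus the log-Jacobian of tanh."""
    correction = sum(math.log(1.0 - math.tanh(u_i) ** 2 + eps) for u_i in u)
    return diag_gaussian_log_prob(u, mu, log_std) - correction
```

Since 1 - tanh(u)^2 is at most 1, the correction raises the log-density (apart from the tiny eps offset), and it grows sharply as samples approach the box edges. Removing the squash removes this term along with the bounds.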

If you don't squash the output, you will still need to clip the actions to respect the action limits (and the agent will not be aware of the clipping, which usually makes things even worse).
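A minimal sketch of the clipping alternative, with hypothetical helper names: sample from an unsquashed Gaussian and clip only when sending the action to the environment. Many different raw samples map to the same executed action, while the log-probability the learner uses is still computed on the raw sample, so the policy gets no direct signal that a bound was hit.

```python
import random
from typing import List, Sequence


def clip_action(raw: Sequence[float], low: Sequence[float],
                high: Sequence[float]) -> List[float]:
    """Element-wise clip of a raw (unsquashed) action into [low, high]."""
    return [min(max(a, lo), hi) for a, lo, hi in zip(raw, low, high)]


def sample_unsquashed(mu: Sequence[float], std: Sequence[float],
                      rng: random.Random) -> List[float]:
    """Draw one sample from a diagonal Gaussian, without any tanh."""
    return [rng.gauss(m, s) for m, s in zip(mu, std)]


if __name__ == "__main__":
    rng = random.Random(0)
    low, high = [-1.0], [1.0]
    raws = [sample_unsquashed([0.0], [2.0], rng) for _ in range(1000)]
    executed = [clip_action(r, low, high) for r in raws]
    # Clipped components equal the bound exactly, so this counts clipped samples
    at_bound = sum(1 for e in executed if abs(e[0]) == 1.0)
    print(f"{at_bound} / 1000 samples were clipped to a bound")
```

With std 2 and a zero mean, about 62% of samples are clipped in expectation, and every one of them executes as exactly -1 or 1 regardless of how far out the raw sample was.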

ManifoldFR commented 4 years ago

I haven't tried a custom Actor yet precisely because using tanh is much more principled and I'd like to make it work, although DeepMimic's own PPO implementation uses clipping (plus a small penalty loss for violating the limits).
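DeepMimic's exact loss is not reproduced here. The sketch below assumes a squared hinge on the part of the policy mean that falls outside [low, high], which is one common way to write such a bound-violation penalty; the function name and the `coef` weight are hypothetical.

```python
from typing import Sequence


def bound_violation_penalty(mu: Sequence[float], low: Sequence[float],
                            high: Sequence[float], coef: float = 1.0) -> float:
    """Squared-hinge penalty on action means outside [low, high].

    Zero when every component is within bounds; grows quadratically with
    the distance past the violated bound.
    """
    total = 0.0
    for m, lo, hi in zip(mu, low, high):
        below = min(m - lo, 0.0)  # negative only if m < lo
        above = max(m - hi, 0.0)  # positive only if m > hi
        total += below ** 2 + above ** 2
    return 0.5 * coef * total
```

Added to the actor loss, a term like this pulls the mean back inside the box while the environment-side clipping stays in place.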

you may need to make some edits if you want the log_std to depend on the state, but at first I would try without

I guess I missed that detail about the SAC implementation, that the noise always depends on the state even without SDE. Thanks!

araffin commented 4 years ago

I guess I missed that detail about the SAC implementation, that the noise always depends on the state even without SDE

For a more detailed answer ;) : https://github.com/hill-a/stable-baselines/issues/652#issuecomment-575900359

ManifoldFR commented 4 years ago


Thanks for this answer; it addressed exactly what I was wondering about constant vs. state-dependent stds!
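To illustrate the distinction: with a constant std, a single learned log_std vector is shared by every observation; with a state-dependent std, log_std is produced by a linear layer on the actor's latent features, which is consistent with `actor.log_std` being treated as a layer with a `.bias` in the initialization code in this thread. The clamp range [-20, 2] reflects my understanding of SB3's SAC, but treat those constants as an assumption; the function names are hypothetical.

```python
import math
from typing import List, Sequence

LOG_STD_MIN, LOG_STD_MAX = -20.0, 2.0  # assumed clamp range


def constant_std(log_std: Sequence[float], latent: Sequence[float]) -> List[float]:
    """State-independent: the same std for every observation (latent is ignored)."""
    return [math.exp(ls) for ls in log_std]


def state_dependent_std(weights: Sequence[Sequence[float]], bias: Sequence[float],
                        latent: Sequence[float]) -> List[float]:
    """State-dependent: log_std = clamp(W @ latent + b), one row of W per action dim."""
    stds = []
    for row, b in zip(weights, bias):
        log_std = sum(w * x for w, x in zip(row, latent)) + b
        log_std = min(max(log_std, LOG_STD_MIN), LOG_STD_MAX)
        stds.append(math.exp(log_std))
    return stds
```

In the second case the exploration noise changes from state to state, even without gSDE, which is the detail discussed above.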

araffin commented 4 years ago

Any news regarding this issue? (You can also try squashing the output for PPO; this feature is implemented.)

ManifoldFR commented 4 years ago

I wrote a custom policy class for SAC with the squashing removed and tested it. The initial actions were a bit all over the place; I'd have to do more testing later, but I don't have the time at the moment.

For DeepMimic, I was able to get the agent to train by modifying the network init a bit more:


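```python
from functools import partial

import numpy as np
from stable_baselines3.sac.policies import Actor, SACPolicy
from termcolor import colored  # only used for colored log output


def init_sac_policy(policy: SACPolicy,
                    output_scale: float = 0.01,
                    log_std_init: float = -3.0) -> None:
    """Initialize a policy for Soft Actor-Critic: orthogonal init
    with a given output scale.

    Hidden layers use gain sqrt(2); the mean head (and the log_std head when
    gSDE is off) use ``output_scale`` so initial outputs stay close to zero.
    The log_std bias is then set to ``log_std_init`` (std = exp(log_std_init)).
    """
    actor: Actor = policy.actor
    # critic: ContinuousCritic = policy.critic
    # critic_target: ContinuousCritic = policy.critic_target
    module_gains = {
        actor.features_extractor: np.sqrt(2),
        actor.latent_pi: np.sqrt(2),
        actor.mu: output_scale,
    }
    # Without gSDE, log_std is an nn.Linear head on the latent features
    # (state-dependent std); with gSDE it is an nn.Parameter set up separately.
    if not actor.use_sde:
        module_gains[actor.log_std] = output_scale

    for module, gain in module_gains.items():
        print(colored(f"Init {module} w/ gain {gain}", "yellow"))
        # init_weights: orthogonal weights and zero bias for Linear/Conv2d layers
        module.apply(partial(policy.init_weights, gain=gain))

    if not actor.use_sde:
        # Shift the initial log std: std = exp(log_std_init), about 0.05 for -3
        actor.log_std.bias.data.fill_(log_std_init)
```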
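A rough way to see why this initialization may help with the edge problem from the original post: the small output gain keeps the initial means near 0, and `log_std_init=-3` gives std = exp(-3), about 0.05. The sketch below assumes the mean starts at exactly 0 and the log_std head outputs exactly its bias (in practice the small-gain weights add a little on top), and compares against log_std = 0 (std = 1) purely as an illustrative baseline. `THRESHOLD` and `stds_to_edge` are hypothetical names.

```python
import math

THRESHOLD = 0.9  # illustrative: "near the edge" means |tanh(u)| > 0.9
pre_squash_cutoff = math.atanh(THRESHOLD)  # |u| must exceed this, about 1.47


def stds_to_edge(log_std: float, mu: float = 0.0) -> float:
    """Number of standard deviations from the mean to the nearer edge cutoff."""
    std = math.exp(log_std)
    return (pre_squash_cutoff - abs(mu)) / std


for log_std in (0.0, -3.0):
    print(f"log_std={log_std:+.0f}: std={math.exp(log_std):.3f}, "
          f"edge is {stds_to_edge(log_std):.1f} std away")
```

The edge cutoff is about 1.5 standard deviations away at std = 1 but about 30 at std of roughly 0.05, so, at least at the start of training, the squashed actions stay well inside the box.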