vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

Add Polyak update to DQN #346

Closed manjavacas closed 1 year ago

manjavacas commented 1 year ago

Problem Description


Current Behavior

Currently, the DQN implementation does a hard update of the target network. However, it is possible to perform soft updates instead, using a soft-update coefficient between 0 and 1 (a Polyak update).
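
For reference, the soft (Polyak) update rule from the DDPG paper, with τ the soft-update coefficient:

    θ_target ← τ · θ_online + (1 − τ) · θ_target,    τ ∈ (0, 1]

so τ = 1 recovers the current hard copy of the weights.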

Expected Behavior

Soft updates can increase the stability of learning, as detailed in the original DDPG paper, because the target values are then constrained to change slowly.

Although this idea came after the original implementation of DQN, it is equally applicable to this algorithm.

Finally, this is also implemented in other reference libraries such as StableBaselines3, although I would understand if it is not added, for the sake of simplicity and adherence to the original DQN implementation.

Possible Solution

In the current DQN implementation, replace:

    # update the target network
    if global_step % args.target_network_frequency == 0:
        target_network.load_state_dict(q_network.state_dict())

with:

    # soft-update (Polyak-average) the target network
    if global_step % args.target_network_frequency == 0:
        for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
            # polyak_coeff in (0, 1]: 1.0 reproduces the current hard update
            target_network_param.data.copy_(
                polyak_coeff * q_network_param.data + (1.0 - polyak_coeff) * target_network_param.data
            )
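
If adopted, the coefficient would also need to be exposed as a flag. A minimal sketch, assuming a hypothetical `--tau` argument added to the existing argparse setup (the name and default here are illustrative, not the final API):

    import argparse

    parser = argparse.ArgumentParser()
    # hypothetical flag: tau = 1.0 reproduces the current hard update exactly
    parser.add_argument("--tau", type=float, default=1.0,
        help="the soft update coefficient of the target network (1.0 = hard update)")
    args = parser.parse_args()

With a default of 1.0, existing DQN results would be unchanged.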
vwxyzjn commented 1 year ago

Hi, thanks for raising this issue. This sounds like a good idea, especially since we are already doing Polyak updates in https://github.com/vwxyzjn/cleanrl/blob/3f5535cab409a34e9f071c10b96a234925d8a8d5/cleanrl/dqn_jax.py#L231 (see the optax docs on optax.incremental_update).
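
For reference, a minimal standalone sketch of that pattern, assuming only `jax` and `optax` are installed (the parameter pytrees and `tau` value below are placeholders, not CleanRL's actual setup):

    import jax.numpy as jnp
    import optax

    tau = 0.005  # placeholder soft-update coefficient
    online_params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
    target_params = {"w": jnp.zeros((2, 2)), "b": jnp.zeros(2)}

    # optax.incremental_update(new, old, step_size) returns
    # step_size * new + (1 - step_size) * old for every leaf of the pytree,
    # i.e. exactly the Polyak/soft update discussed above.
    target_params = optax.incremental_update(online_params, target_params, tau)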

Feel free to make a PR.