instadeepai / Mava

🦁 A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX
Apache License 2.0

[BUG] Fix termination vs truncation mixup #951

Open sash-a opened 9 months ago

sash-a commented 9 months ago

Describe the bug

It seems we are not using the correct termination vs truncation values: we always use the combined condition of termination or truncation (timestep.last()) when we often want the condition of termination only (1 - discount). This is especially tricky in the recurrent systems.

Expected behavior

When calculating advantages we should use termination (1 - discount), and in the recurrent systems, when passing inputs to the networks during training, we should use termination or truncation so that the hidden state is reset correctly.
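
For reference, a minimal sketch of extracting the two signals, assuming Jumanji's TimeStep (which mirrors dm_env's); the helper name here is illustrative and not part of Mava:

    from typing import Tuple

    import chex
    from jumanji.types import TimeStep  # dm_env-style timestep used by Mava's environments


    def episode_end_signals(timestep: TimeStep) -> Tuple[chex.Array, chex.Array]:
        """Illustrative helper: split the two episode-end signals.

        At a true termination: timestep.last() is True and timestep.discount == 0.
        At a truncation:       timestep.last() is True but timestep.discount == 1.
        """
        term = 1.0 - timestep.discount   # termination only -> use when calculating advantages
        term_or_trunc = timestep.last()  # either boundary -> use to reset the RNN hidden state
        return term, term_or_trunc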

Possible Solution

Always put 1-discount in the PPOTimestep.done and always put timestep.last() in the RNNLearnerState.done.

To avoid issues like this in the future I think we should rename RnnLearnerState.done to RnnLearnerState.truncation.

It looks like there are a couple of places where we use PPOTimestep.done when it should be RNNLearnerState.done, so we'd have to go through and make sure we're always using the correct one. An example is here and here, where we use PPOTimestep.done (which should be 1 - discount) to reset the hidden state; instead we should pass RnnLearnerState.truncation to the loss functions and use that (see the sketch below).
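
A rough sketch of where each signal would live under that scheme; these NamedTuples are simplified stand-ins for illustration, not Mava's actual definitions:

    from typing import NamedTuple

    import chex


    class PPOTimestep(NamedTuple):
        """Simplified stand-in: carries termination only, for advantage calculation."""

        done: chex.Array  # 1 - timestep.discount
        # ... obs, reward, value, log_prob, etc.


    class RnnLearnerState(NamedTuple):
        """Simplified stand-in: carries the combined signal, for hidden-state resets."""

        truncation: chex.Array  # timestep.last(), i.e. termination OR truncation
        # ... params, opt_states, hidden states, etc.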

qizhg commented 5 months ago

Hi @sash-a, is this fixed?

sash-a commented 5 months ago

Right now no: none of the environments we currently use have truncation, so there's no issue in practice. It is high on the list, but unfortunately the whole team is off this week and we're quite busy with some other research right now. We will try to fix this as soon as possible.

If you have capacity we're more than happy to accept pull requests.

Note this is only an issue in the PPO systems

qizhg commented 5 months ago

Thanks @sash-a. I see that rec_iql.py in the current develop branch distinguishes between termination and truncation. Could you compare this issue in recurrent PPO vs recurrent IQL?

Also, is the fix/term-vs-trunc branch a good fix?

sash-a commented 5 months ago

Yup, both IQL and SAC handle it correctly; it's only the PPO systems that have the issue. All that needs to happen in PPO is that discounts are used when calculating the advantage and timestep.last() is used when resetting the hidden state.
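
For context, the hidden-state reset itself is just re-initialising the carry wherever that flag is set before each recurrent step. A minimal plain-JAX sketch (not Mava's actual ScannedRNN, and the cell below is a dummy stand-in), where reset should be timestep.last(), i.e. termination OR truncation:

    import jax
    import jax.numpy as jnp


    def rnn_step(hidden: jnp.ndarray, inputs: tuple) -> tuple:
        """One recurrent step over a batch: zero the carry at any episode boundary."""
        obs, reset = inputs  # reset = termination OR truncation (timestep.last())
        hidden = jnp.where(reset[:, None], jnp.zeros_like(hidden), hidden)
        new_hidden = jnp.tanh(hidden + obs)  # dummy stand-in for a real GRU/LSTM cell
        return new_hidden, new_hidden


    # Rollout of shape [T, B, H] with a reset in the middle of the sequence.
    hidden0 = jnp.zeros((4, 8))                                 # [B, H]
    obs_seq = jnp.ones((10, 4, 8))                              # [T, B, H]
    reset_seq = jnp.zeros((10, 4), dtype=bool).at[5].set(True)  # [T, B]
    _, outputs = jax.lax.scan(rnn_step, hidden0, (obs_seq, reset_seq))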

Yup, I'm pretty sure that branch works. The problem is that it got very out of date with develop, which made it difficult to merge, and we got a bit busy. It's missing some of Mava's latest features, but the algorithm itself should be correct.

qizhg commented 5 months ago

@sash-a Thanks and have a good day!

sash-a commented 1 month ago

After some experimentation, this helps in certain envs but hurts in others and needs quite a lot more investigation; unfortunately I don't have the time right now. The current progress is on the fix/term-vs-trunc2 branch. I'm aiming to get back to this in a couple of months.

I see that this PR on CleanRL seems to agree that this is important for certain envs, but not all, which makes sense as not all envs truncate.

My current solution for the GAE is to use both terminal and truncated, which I see is the same approach as in the PR above:

    def _calculate_gae(
        traj_batch: PPOTransition, last_val: chex.Array
    ) -> Tuple[chex.Array, chex.Array]:
        """Calculate the GAE."""

        def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple:
            """Calculate the GAE for a single transition."""
            gae, next_value = gae_and_next_value
            term, trunc, value, reward = (
                transition.terminal,
                transition.truncated,
                transition.value,
                transition.reward,
            )

            gamma = config.system.gamma
            # Bootstrap the value through truncations but not through true terminations.
            delta = reward + gamma * next_value * (1 - term) - value
            # Cut the GAE trace at any episode boundary (termination or truncation).
            gae = delta + gamma * config.system.gae_lambda * (1 - (trunc | term)) * gae
            return (gae, value), gae

        _, advantages = jax.lax.scan(
            _get_advantages,
            (jnp.zeros_like(last_val), last_val),
            traj_batch,
            reverse=True,
            unroll=16,
        )
        return advantages, advantages + traj_batch.value

    advantages, targets = _calculate_gae(traj_batch, last_val)
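
To spell out what the code above implements: the value is bootstrapped through truncations but not through true terminations, while the GAE trace itself is cut at any episode boundary:

$$\delta_t = r_t + \gamma\,V(s_{t+1})\,(1 - \text{term}_t) - V(s_t), \qquad \hat{A}_t = \delta_t + \gamma\lambda\,\bigl(1 - (\text{term}_t \lor \text{trunc}_t)\bigr)\,\hat{A}_{t+1}$$

One caveat: the trunc | term bitwise OR assumes both flags are stored as booleans (or integers); if they are kept as floats they would need to be combined with something like jnp.maximum(trunc, term) instead.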