matomatical / jaxgmg

JAX-based environments and RL baselines for studying goal misgeneralisation.
MIT License

37 Implementation details of PPO #17

Open matomatical opened 1 month ago

matomatical commented 1 month ago

Our baselines use a PPO implementation adapted from PureJaxRL, but it doesn't appear to follow all of the relevant implementation details from Huang et al., 2022, 'The 37 Implementation Details of Proximal Policy Optimization' (henceforth 37Details):

We're not necessarily trying to 'replicate PPO' but we should consider trying each of these if/when we get a chance and see if it makes a big difference for our environments.

Core implementation details.

  1. Vectorized architecture.
    • ✅ We're training with a vector of environments.
    • 🤷 A deliberate difference is that we don't carry env states between rollout phases (UED style). This works because our episodes are short and we want to change level layouts frequently.
    • 🤔 Worth trying: 37Details explains how increasing the number of parallel envs might not slow down rollouts and could therefore improve training efficiency. I don't see why we couldn't try making each rollout phase as long as one episode timeout. Note that PPO originally used a rollout phase of length 128 (though see detail (1)).
  2. Orthogonal initialization of weights and constant initialization of biases.
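    • 🤔 We're just using the default flax.linen initialisation, which differs from this: Dense and Conv layers default to LeCun-normal kernels with zero biases, so our biases already match but our weights are not orthogonal.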
  3. Adam epsilon parameter.
    • 🤔 We're just using the optax default (eps=1e-8). 37Details recommends 1e-5, the value used by the original implementation, but perhaps this only matters for a close replication.
  4. Adam learning rate parameter.
    • ❌ We mostly don't anneal the learning rate, whereas 37Details says the original implementation decays it linearly to zero and that this can help.
  5. Generalized advantage estimation.
    • ✅ We are using GAE.
    • 🤔 Are we using value bootstrap? I can't tell.
    • ✅ We are using TD lambda returns.
  6. Mini-batch updates.
    • ✅ We are shuffling the data rather than subsampling, as they suggest.
    • ✅ UPDATE: The new default is 8 minibatches, which works better than the previous default of 1. Note that PPO originally used 4 minibatches.
  7. Normalization of advantages.
    • ✅ We are normalising the advantage estimates (at the minibatch level).
  8. Clipped surrogate objective.
    • ✅ I haven't checked the details carefully, but I assume PureJaxRL got this right. This 'implementation detail' mostly just comments on why it's a good objective...
  9. Value function loss clipping.
    • ✅ We do this.
    • ❌ 37Details cites work finding that it is not necessary and may even hurt performance. We could try turning it off.
  10. Overall loss and entropy bonus.
    • ✅ We include an entropy term in the overall loss.
    • 🤔 One work cited found the entropy term has no effect on performance in a continuous control setting.
  11. Global gradient clipping.
    • ✅ We use global grad clipping with max grad norm 0.5 as in the original.
  12. Debug variables.
    • ✅ We track several debug variables, including policy_loss, value_loss and entropy_loss.
    • ✅ UPDATE: We now additionally track clipfrac (the fraction of training data triggering the clipped objective) and approxkl (KL estimators using both (-logratio).mean() and ((ratio - 1) - logratio).mean(); see linked post).
    • 🤷 We also track critic clipfrac (though see 9).
  13. Shared vs. separate MLP networks for policy and value functions.
    • 🤔 They show that in classic control environments, an architecture with separate actor and critic networks gives a big improvement. We don't use MLPs but an IMPALA network, and making the same change would make our network a LOT bigger, but it may be worth trying to see if it helps.
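For detail (2), a NumPy sketch of orthogonal initialisation with the gains 37Details uses (sqrt(2) for hidden layers, 0.01 for the policy head, 1 for the value head) and zero biases. It handles 2-D weight matrices only. In flax.linen the equivalent would be passing `kernel_init=nn.initializers.orthogonal(scale)` and `bias_init=nn.initializers.constant(0.0)` to each layer; the helper names below are hypothetical.

```python
import numpy as np


def orthogonal(rng, shape, gain=1.0):
    """Orthogonal init for a 2-D weight matrix of shape (fan_in, fan_out)."""
    rows, cols = shape
    # QR of a tall Gaussian matrix gives orthonormal columns.
    a = rng.standard_normal((max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)
    # Sign correction so the result is uniformly distributed.
    q = q * np.sign(np.diag(r))
    if rows < cols:
        q = q.T  # orthonormal rows instead of columns
    return gain * q


def init_layer(rng, fan_in, fan_out, gain):
    """Hypothetical helper: orthogonal weights, constant (zero) biases."""
    return {"w": orthogonal(rng, (fan_in, fan_out), gain), "b": np.zeros(fan_out)}


rng = np.random.default_rng(0)
hidden = init_layer(rng, 256, 256, gain=np.sqrt(2))
policy_head = init_layer(rng, 256, 4, gain=0.01)
value_head = init_layer(rng, 256, 1, gain=1.0)
```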
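On the bootstrap question in (5): a minimal NumPy sketch of GAE with a value bootstrap, assuming time-major [T, N] arrays where dones[t] marks that the episode ended at step t. The bootstrap is the `last_value` term, the critic's estimate for the observation reached after the final rollout step; checking whether our rollout code feeds something like it into the advantage calculation should settle the question. The function name and argument layout are hypothetical, not taken from our code or PureJaxRL.

```python
import numpy as np


def gae(rewards, values, dones, last_value, gamma=0.999, lam=0.95):
    """Generalised advantage estimation over a time-major rollout.

    rewards, values, dones: float arrays of shape [T, N].
    dones[t] == 1.0 means the episode ended at step t, so the next value
    belongs to a new episode and must not be bootstrapped from.
    last_value: shape [N], the value bootstrap for the state after step T-1.
    Returns (advantages, returns); returns are the TD(lambda) targets.
    """
    T = rewards.shape[0]
    advantages = np.zeros_like(rewards)
    next_adv = np.zeros_like(last_value)
    next_value = last_value
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        # One-step TD error, bootstrapping only if the episode continues.
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Exponentially weighted sum of TD errors, cut at episode ends.
        next_adv = delta + gamma * lam * not_done * next_adv
        advantages[t] = next_adv
        next_value = values[t]
    return advantages, advantages + values
```

Since we don't carry env states between rollout phases, `last_value` only matters for episodes still running when the phase ends.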
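For details (7) to (12), a NumPy sketch of a per-minibatch PPO loss in the 37Details/CleanRL style. Value clipping sits behind a `clip_value` flag so it is easy to try turning off, and the debug metrics include clipfrac and both approxkl estimators (k1 = -logratio and k3 = (ratio - 1) - logratio, following the linked post). `clip_eps=0.1` and `ent_coef=0.001` match the defaults mentioned in this issue; `vf_coef=0.5` is an assumption. The names are hypothetical and this is not our actual loss code.

```python
import numpy as np


def ppo_loss(new_logp, old_logp, advantages, new_values, old_values,
             returns, entropy, clip_eps=0.1, vf_coef=0.5, ent_coef=0.001,
             clip_value=True):
    """PPO loss for one minibatch plus debug metrics (all inputs shape [B])."""
    # (7) Normalise advantages at the minibatch level.
    adv = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # (8) Clipped surrogate objective, negated because we minimise.
    logratio = new_logp - old_logp
    ratio = np.exp(logratio)
    surr = np.minimum(ratio * adv, np.clip(ratio, 1 - clip_eps, 1 + clip_eps) * adv)
    pg_loss = -surr.mean()

    # (9) Value loss, optionally clipped around the old predictions.
    v_loss_unclipped = (new_values - returns) ** 2
    if clip_value:
        v_clipped = old_values + np.clip(new_values - old_values, -clip_eps, clip_eps)
        v_loss = 0.5 * np.maximum(v_loss_unclipped, (v_clipped - returns) ** 2).mean()
    else:
        v_loss = 0.5 * v_loss_unclipped.mean()

    # (10) Entropy bonus: subtracting it rewards higher entropy.
    ent = entropy.mean()
    total = pg_loss + vf_coef * v_loss - ent_coef * ent

    # (12) Debug metrics.
    metrics = {
        "policy_loss": pg_loss,
        "value_loss": v_loss,
        "entropy": ent,
        "clipfrac": (np.abs(ratio - 1) > clip_eps).mean(),
        "approxkl_k1": (-logratio).mean(),
        "approxkl_k3": ((ratio - 1) - logratio).mean(),  # always >= 0
    }
    return total, metrics
```

For example, with an old value of 0, a new value of 1 and a return of 1, the unclipped value loss is 0 but the clipped one is 0.5 × (0.1 - 1)² = 0.405, because the maximum of the two terms is taken; this is one way value clipping can slow the critic down.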

Atari-specific implementation details. Most of these are n/a.

  1. Take a random number of no-ops on reset.
    • 🤷 We don't do anything like this, instead we sample completely different levels anyway.
  2. Skip 4 frames (repeating actions, accumulating reward, and carefully avoiding rendering artefacts). This speeds up training.
    • 🤷 We don't need to do this because the environments are already neatly chunked into logical frames.
    • 🤔 If we eventually add a platformer with smaller differences between frames, we could consider this.
  3. Mark the end of a 'life' as the end of episode.
    • 🤷 N/A. We design environments not to have lives.
  4. Automatically take the FIRE action at the start of each episode for environments that don't start until this happens.
    • 🤷 N/A. We design environments not to need that kind of thing.
  5. Resize color frames to 84x84 pixel grayscale.
    • 🤷 A deliberate difference is that we use either boolean or RGB-colour observations. We offer good control over this.
    • 🤔 I guess if/when we train with RGB observations we could try to stick to this rough size and we could consider reducing to greyscale, though I think the RGB doesn't hurt.
  6. Clip rewards to their sign (-1, 0, or +1).
    • ✅ We effectively do this, indirectly: our environments only produce such rewards to begin with.
    • 🤔 We also have an option to scale rewards down over the course of an episode. I think it's off by default, and it's worth experimenting to see its effect.
  7. Stack 4 frames.
    • ❌ We don't offer this at the moment. It would actually be pretty easy to add, either via the base environment or via each environment (and we could even save parameters by only stacking the dynamic channels in the Boolean obs setting). It also seems worth trying as a way to test whether the LSTM is working.
  8. Again, shared vs. separate NatureCNN architecture.
    • 🤷 Again, we use IMPALA.
    • 🤔 We could try NatureCNN as an option btw.
  9. Image values should be in the range [0, 1]. Otherwise, KL explodes.
    • ✅ Yup, they are 0 or 1 for boolean obs and [0,1] for RGB.
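For (7), a NumPy sketch of frame stacking for a single env, assuming [H, W, C] observations and a known list of dynamic channel indices (hypothetical; which channels are dynamic depends on each environment). Static channels appear once, so with C channels of which D are dynamic, stacking k frames gives C - D + k·D input channels instead of k·C. The buffer restarts on reset so frames don't leak across episodes. All function names are hypothetical.

```python
import numpy as np


def init_stack(obs, num_frames=4):
    """Start a frame buffer [num_frames, H, W, C] by repeating the first obs."""
    return np.repeat(obs[None], num_frames, axis=0)


def push_frame(buffer, obs, done):
    """Append obs (dropping the oldest frame), or restart the buffer on reset."""
    if done:
        # obs is the first observation of a new episode: drop old frames.
        return init_stack(obs, buffer.shape[0])
    return np.concatenate([buffer[1:], obs[None]], axis=0)


def stacked_obs(buffer, dynamic_channels):
    """Agent input: static channels once, dynamic channels from every frame."""
    latest = buffer[-1]
    static_mask = np.ones(latest.shape[-1], dtype=bool)
    static_mask[dynamic_channels] = False
    static = latest[..., static_mask]
    # [k, H, W, D] -> [H, W, k, D] -> [H, W, k*D], oldest frame first.
    dynamic = buffer[..., dynamic_channels]
    dynamic = np.moveaxis(dynamic, 0, -2).reshape(*latest.shape[:-1], -1)
    return np.concatenate([static, dynamic], axis=-1)
```

In a vectorised JAX version, the `if done` branch would become a `jnp.where` over the batch of envs.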

Hyper-parameters from the original PPO paper (for some reason these aren't considered part of the 37 details).

  1. PPO originally used 128 steps per parallel env per rollout phase.
    • 🤷 We were using 256 and now 128.
    • 🤔 Could probably get away with fewer; worth trying.
  2. PPO originally used 8 parallel environments.
    • 🤷 We originally used 32, now we use 256 and it works better.
  3. GAE lambda of 0.95.
    • ✅ We use the same by default.
  4. Gamma of 0.99.
    • 🤷 We've been using 0.999.
  5. Entropy coefficient of 0.01.
    • 🤷 We've been using 0.001 instead.
  6. Learning rate schedule (see above).
  7. PPO clipping parameter 0.1.
    • ✅ We were using 0.2. I have seen both suggested. But 0.1 seems to be more stable for very long runs, so that's the new default.
  8. 4 minibatches.
    • 🤷 UPDATE: New default is to use 8 minibatches, works better than previous (1), about same/slightly better than 4.
  9. 4 epochs per learning phase (?)
    • 🤷 We've been using 5, seems fine.
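To make the comparison above concrete, a small sketch of the batch arithmetic, plus the usual 1/(1 - gamma) rule of thumb for the effective horizon. The original-PPO numbers are the ones listed above (including the uncertain 4 epochs); the function names are hypothetical.

```python
def batch_shapes(num_envs, steps_per_env, num_minibatches, num_epochs):
    """Return (transitions per update, minibatch size, gradient steps per update)."""
    batch = num_envs * steps_per_env
    if batch % num_minibatches:
        raise ValueError("num_minibatches must divide the batch size")
    return batch, batch // num_minibatches, num_minibatches * num_epochs


def effective_horizon(gamma):
    """Rule of thumb: rewards ~1/(1 - gamma) steps ahead are discounted by ~1/e."""
    return 1.0 / (1.0 - gamma)


original = batch_shapes(num_envs=8, steps_per_env=128, num_minibatches=4, num_epochs=4)
ours = batch_shapes(num_envs=256, steps_per_env=128, num_minibatches=8, num_epochs=5)
```

So each of our updates uses 32768 transitions, 32× the original 1024, split into minibatches of 4096 rather than 256, and gamma = 0.999 looks roughly ten times further ahead than 0.99 (about 1000 steps versus 100).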

Details for continuous action domains. Most of these are really N/A and not even worth writing down, with one exception.

  1. Reward scaling.
    • 🤷 We do not use 'a discount-based scaling scheme where the rewards are divided by the standard deviation of a rolling discounted sum of the rewards (without subtracting and re-adding the mean).' However, doing so seems pretty unprincipled? I think we are fine without it.
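For reference, a sketch of what the quoted scheme does, modelled on the running-statistics approach in OpenAI Baselines' VecNormalize (the class and argument names here are hypothetical): keep a per-env discounted running return, track its running variance, and divide each reward by the running standard deviation without subtracting the mean.

```python
import numpy as np


class RewardScaler:
    """Divide rewards by the std of a rolling discounted return (no mean shift)."""

    def __init__(self, num_envs, gamma=0.999, eps=1e-8):
        self.gamma = gamma
        self.eps = eps
        self.ret = np.zeros(num_envs)  # rolling discounted sum per env
        # Running statistics of self.ret (small initial count avoids div by 0).
        self.count, self.mean, self.var = 1e-4, 0.0, 1.0

    def _update_stats(self, x):
        # Parallel (batched) mean/variance update.
        b_mean, b_var, b_count = x.mean(), x.var(), x.size
        delta = b_mean - self.mean
        total = self.count + b_count
        new_mean = self.mean + delta * b_count / total
        m2 = (self.var * self.count + b_var * b_count
              + delta ** 2 * self.count * b_count / total)
        self.count, self.mean, self.var = total, new_mean, m2 / total

    def __call__(self, rewards, dones):
        self.ret = self.ret * self.gamma + rewards
        self._update_stats(self.ret)
        scaled = rewards / np.sqrt(self.var + self.eps)  # scale only, no centring
        self.ret = self.ret * (1.0 - dones)  # reset the sum at episode ends
        return scaled
```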

LSTM implementation details.

  1. Layer initialization for LSTM layers.
    • 🤔 In the PPO implementation, the LSTM layers' weights are initialised with std=1 and biases with 0. In contrast, we use flax linen defaults, which may differ.
  2. Initialize LSTM states to be zeros.
    • ✅ We are just using the flax linen default, but I confirmed this is what it does (regardless of the rng passed to the cell's initialize_carry method).
  3. Reset LSTM states at the end of the episode.
    • ✅ We do this in effect.
    • 🤷 By personal choice I didn't do it by passing the done vector to the forward pass, instead it's managed by the environment step method during the rollout phase.
  4. Prepare sequential rollouts in mini-batches.
    • ❌ We don't do this. We still shuffle the old data.
    • 🤔 But is that OK, given that we store the LSTM states from the rollout and so don't need to reconstruct them? See also the next detail.
    • 🤔 The only problem I can see is if we are somehow meant to actually run the loss function over a whole trajectory, so that gradients can pass back through the LSTM to earlier steps and help it learn to use its memory effectively. NEED TO GET TO THE BOTTOM OF THIS.
  5. Reconstruct LSTM states during training.
    • ❌ We don't do this.
    • 🤔 Do we need to do this? We store the states collected during rollouts instead. This uses extra space but means we can shuffle the batches later, which might be a good thing, pending my confusions around (4) above.
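For (4) and (5), a sketch of how 37Details-style sequential minibatches could be prepared: shuffle environments rather than timesteps, so each minibatch holds whole length-T trajectories together with the LSTM carry from the start of the rollout. The loss would then re-run the LSTM over each sequence from that carry (resetting at dones), so gradients flow back through time. By contrast, shuffling individual timesteps with stored carries treats each carry as a constant, so gradients only pass through a single LSTM step, which is the concern raised in (4). The batch layout and key names are assumptions.

```python
import numpy as np


def sequential_minibatches(rng, batch, num_minibatches):
    """Split a time-major rollout into minibatches of whole trajectories.

    batch: dict of arrays shaped [T, N, ...] (T steps, N envs), plus the
    LSTM carry at the start of the rollout, shaped [N, ...], under "init_carry".
    Shuffles environments, never timesteps, so sequences stay intact.
    """
    num_envs = batch["init_carry"].shape[0]
    if num_envs % num_minibatches:
        raise ValueError("num_minibatches must divide the number of envs")
    perm = rng.permutation(num_envs)
    for env_idx in np.split(perm, num_minibatches):
        yield {
            k: (v[env_idx] if k == "init_carry" else v[:, env_idx])
            for k, v in batch.items()
        }
```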

There is a final detail of multi-discrete action support, which is not necessary for me at the moment.

Auxiliary implementation details (not used by original PPO implementation, but potentially useful in some situations).

  1. Clip range annealing.
    • ❌ We don't do this. We should try it.
  2. Parallelized gradient update.
    • 🤷 N/A for our compute setup.
  3. Early stopping of the policy optimizations.
    • 🤔 Needs more thought. We are having some weird issues in generalisation shifting but I don't know why.
  4. Invalid action masking.
    • 🤷 N/A since we don't have invalid actions like they do in general-purpose emulators.
    • 🤔 Actually, we do often have invalid actions in some states, e.g. walking against a wall, which are sometimes useful but mostly not (they would never be useful if we provided an always-available 5th no-op action, see issue #16). We could consider masking these out and seeing if it helps our agents.
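If we try masking, a minimal sketch of the usual approach: replace the logits of invalid actions with a large negative number before the softmax, so they get (numerically) zero probability and contribute nothing to the entropy. This assumes the environment can expose a boolean mask of valid actions per state, which ours currently don't; the mask and function name are hypothetical.

```python
import numpy as np


def masked_log_probs(logits, valid):
    """Log-softmax over actions, giving invalid actions zero probability.

    logits: [..., A]; valid: bool [..., A] with at least one True per row.
    """
    # A large finite negative (rather than -inf) avoids NaNs such as
    # 0 * -inf when computing the entropy.
    masked = np.where(valid, logits, -1e9)
    masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability
    return masked - np.log(np.exp(masked).sum(axis=-1, keepdims=True))
```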
matomatical commented 1 month ago

Our original minibatch shape defaults were bad.

I swept over a few values while monitoring performance, wall-clock time and GPU utilisation (on an RTX 4090). Running with the following options:

--num-parallel-envs 256 --num-env-steps-per-cycle 128 --num-minibatches-per-epoch 8

made training quite a bit faster (around 40% faster on the RTX 4090), and performance metrics seem roughly the same or maybe a bit better.

So we made these the new defaults.

matomatical commented 1 month ago

❌ We could track several additional debug variables, namely clipfrac (fraction of training data triggering the clipped objective) and approxkl (KL estimator using (-logratio).mean() or ((ratio - 1) - logratio).mean(), see linked post).

Commit 090dcec9751fd3f85e0e0cbcf7123df4844a3ae3 introduces four new metrics:

matomatical commented 1 month ago

PPO clipping parameter 0.1.

  • ❌ We've been using 0.2. I have seen both suggested. It's probably fine, but we could double check.

Usman ran some experiments and found that 0.1 works better, solving a stability issue we had encountered.

We should use this going forward (along with lr=5e-5 instead of 5e-4).

We should still remember to try both clip range annealing and learning rate annealing at some point.
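A sketch of linear annealing for both the learning rate and the PPO clip range, starting from the values suggested above (lr=5e-5, clip 0.1) and decaying to zero over training. The helper names are hypothetical. With optax, the learning-rate half could instead use `optax.linear_schedule(init_value, end_value, transition_steps)`, passing the resulting schedule as `learning_rate` to `optax.adam`; note that optax counts optimiser steps (minibatches × epochs per update), not updates. The clip range would still need to be computed per update and passed to the loss.

```python
def linear_anneal(initial, step, total_steps, final=0.0):
    """Linearly interpolate from `initial` to `final` over `total_steps`."""
    frac = min(step / total_steps, 1.0)
    return initial + frac * (final - initial)


def annealed_hparams(update, num_updates, lr0=5e-5, clip0=0.1):
    """Hypothetical per-update schedule for learning rate and clip range."""
    return {
        "lr": linear_anneal(lr0, update, num_updates),
        "clip_eps": linear_anneal(clip0, update, num_updates),
    }
```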