facebookresearch / minimax

Efficient baselines for autocurricula in JAX.

PLR Reset Behaviour #6

Closed Michael-Beukman closed 2 months ago

Michael-Beukman commented 3 months ago

I was wondering about the PLRRunner's rollouts. It seems to me that its get_transition function (which is the same as the DRRunner's) does not use the reset_state argument of env.step. I think that means that when the environment is done, the auto-reset code triggers and generates a new level randomly (i.e., a DR level). If that is the case, the levels after the first episode are DR ones, which may cause problems for the MaxMC score calculation. It also means that the agent sometimes trains on randomly generated levels.
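To make the behaviour concrete, here is a minimal sketch of the auto-reset pattern I mean (hypothetical env API and names, not minimax's actual internals): on done, the rollout either samples a fresh DR level or, if reset_state is supplied, restores the level the episode started from.

```python
import jax

# Minimal sketch of the auto-reset pattern in question (hypothetical env API).
# On `done`, the current behaviour samples a fresh random (DR) level;
# supplying `reset_state` would instead replay the original level.
def step_with_autoreset(rng, env, state, action, reset_state=None):
    next_state, reward, done = env.step(rng, state, action)

    def on_done(_):
        if reset_state is not None:      # resolved at trace time
            return reset_state           # replay the original (PLR-sampled) level
        return env.reset(rng)            # sample a brand-new DR level

    next_state = jax.lax.cond(done, on_done, lambda _: next_state, None)
    return next_state, reward, done
```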

minqi commented 3 months ago

Yes, in the current implementation, PLRRunner will sample domain-randomized levels after the first episode whenever a rollout dimension completes. A simple fix would be to use the reset_state argument directly in a separate runner, though new hyperparameters would have to be determined for the updated behavior.
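As a sketch of that fix (hypothetical names; `env` and `policy` stand in for the runner's actual components), the runner would remember the level each rollout dimension started from and pass it back into env.step on every transition:

```python
import jax

# Sketch of the simple fix (hypothetical names): thread the level each rollout
# dimension started from through the scan, and pass it as `reset_state` so
# done episodes replay the same PLR-sampled level instead of a DR one.
def rollout_step(carry, _):
    rng, state, start_state = carry
    rng, step_rng = jax.random.split(rng)
    action = policy(state)  # `policy` is an assumed stand-in for the agent
    next_state, reward, done = env.step(
        step_rng, state, action, reset_state=start_state)
    return (rng, next_state, start_state), (reward, done)

# e.g. scanned over the rollout length:
# (_, last_state, _), traj = jax.lax.scan(
#     rollout_step, (rng, init_state, init_state), None, length=rollout_len)
```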

However, one issue with this simple fix is that resetting to the first level per rollout dimension could mean training on significantly fewer distinct levels, depending on the average episode length. The previous PyTorch implementation resamples from the PLR buffer itself when an episode resets mid-rollout. Implementing similar behavior here would be slightly more involved, as the PLR buffer is stateful, and its state updates with every sample (due to the staleness scores). I sketched an approach a few months back but haven't had time to implement it.
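To illustrate why this is more involved, here is a sketch of the bookkeeping (hypothetical buffer fields, not the repo's actual buffer API): every sample mutates the buffer state, so that state would have to be threaded through the rollout itself.

```python
import jax
import jax.numpy as jnp

# Sketch (hypothetical fields) of why sampling from the PLR buffer is stateful:
# each draw refreshes the chosen level's staleness and advances a global step,
# so the updated buffer state must be carried through the (vmapped) rollout.
def sample_from_buffer(rng, buffer_state):
    scores = buffer_state["scores"]              # per-level learning potential
    last_sampled = buffer_state["last_sampled"]  # step of each level's last draw
    step = buffer_state["step"]

    staleness = step - last_sampled
    # PLR mixes a score-based and a staleness-based sampling distribution.
    probs = (0.9 * jax.nn.softmax(scores)
             + 0.1 * staleness / jnp.maximum(staleness.sum(), 1))
    idx = jax.random.choice(rng, scores.shape[0], p=probs)

    # The sample itself updates the state: the chosen level becomes fresh,
    # and the advancing step makes every other level more stale.
    new_state = {
        "scores": scores,
        "last_sampled": last_sampled.at[idx].set(step),
        "step": step + 1,
    }
    return idx, new_state
```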

Another possibility is to try new heuristics for tracking staleness that require simpler bookkeeping in JAX. I think this approach may be the most practical and effective.
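For example, one such heuristic (purely a sketch of the idea, reusing the hypothetical buffer fields above) could update staleness once per rollout batch rather than once per sample, so mid-rollout resets need no extra state threading:

```python
# Sketch: refresh every level used in a batch with one vectorized update,
# instead of mutating the buffer state on each individual sample.
def update_staleness_per_batch(buffer_state, sampled_idxs):
    last = buffer_state["last_sampled"].at[sampled_idxs].set(buffer_state["step"])
    return {**buffer_state, "last_sampled": last, "step": buffer_state["step"] + 1}
```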

For now, I'll look into simply resetting to the first level per rollout batch. I'll run some sweeps to evaluate this change and share the results.