EmptyJackson / groove

Official implementation of the NeurIPS 2023 paper "Discovering General Reinforcement Learning Algorithms with Adversarial Environment Design"
Apache License 2.0

Meta-Learned RL Objective Functions in JAX

GROOVE is the official implementation of the following publications:

  1. Discovering General Reinforcement Learning Algorithms with Adversarial Environment Design, NeurIPS 2023 [ArXiv | NeurIPS | Twitter]
    • Learned Policy Gradient (LPG),
    • Prioritized Level Replay (PLR),
    • General RL Algorithms Obtained Via Environment Design (GROOVE),
    • Grid-World environment from the LPG paper.
  2. Discovering Temporally-Aware Reinforcement Learning Algorithms, ICLR 2024 [ArXiv]
    • Temporally-Aware LPG (TA-LPG),
    • Evolutionary Strategies (ES) with antithetic task sampling.

All scripts are JIT-compiled end-to-end and make extensive use of JAX-based parallelization, enabling meta-training in under 3 hours on a single GPU!
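The following minimal sketch illustrates the general pattern described above. It is not code from this repository. jax.vmap batches independent per-agent updates, and jax.lax.scan inside jax.jit compiles the whole training loop into a single program. The toy quadratic loss and the names loss_fn, agent_update and train are hypothetical stand-ins for the real LPG agent update.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, target):
    # Toy per-agent loss (hypothetical stand-in for an RL objective).
    return jnp.sum((params - target) ** 2)


def agent_update(params, target, lr=0.1):
    # One gradient-descent step for a single agent.
    grads = jax.grad(loss_fn)(params, target)
    return params - lr * grads


@partial(jax.jit, static_argnames="num_steps")
def train(params, targets, num_steps):
    # vmap turns the single-agent update into a batched update over all agents.
    batched_update = jax.vmap(agent_update)

    def step(p, _):
        return batched_update(p, targets), None

    # lax.scan expresses the training loop as a compiled loop, so the whole
    # run (all agents, all steps) is a single JIT-compiled call.
    final_params, _ = jax.lax.scan(step, params, None, length=num_steps)
    return final_params


num_agents = 512
params = jnp.zeros((num_agents, 4))
targets = jnp.ones((num_agents, 4))
final = train(params, targets, num_steps=10)
print(final.shape)  # (512, 4)
```

Because the loop lives inside the compiled program, Python overhead is paid once at compile time rather than on every step. The trade-off is that everything in the loop body runs on every iteration.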

Update (April 2024): The LPG ES hyperparameters were misreported in both the repository and the paper, specifically the initial learning rate (1e-4 -> 1e-2) and sigma (3e-3 -> 1e-1). Both have now been corrected.
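For readers unfamiliar with ES, here is a minimal sketch of an antithetic ES gradient estimate in JAX. It uses the corrected sigma (1e-1) and initial learning rate (1e-2) from the update note. It assumes that "antithetic task sampling" means both members of each +/- perturbation pair are evaluated on the same sampled task. It also uses a plain gradient-ascent step. The toy fitness function and all names are hypothetical, and the repository's ES setup (for example, its optimizer) may differ.

```python
import jax
import jax.numpy as jnp

SIGMA = 0.1           # perturbation scale (corrected value, 1e-1)
LEARNING_RATE = 0.01  # initial learning rate (corrected value, 1e-2)


def fitness_fn(theta, task_key):
    # Toy fitness (hypothetical): negative squared distance to a task-specific
    # target, standing in for an agent's return on a sampled task.
    target = 1.0 + 0.1 * jax.random.normal(task_key, theta.shape)
    return -jnp.sum((theta - target) ** 2)


def antithetic_es_grad(key, theta, num_pairs=8, sigma=SIGMA):
    eps_key, task_key = jax.random.split(key)
    eps = jax.random.normal(eps_key, (num_pairs,) + theta.shape)
    # One task per antithetic pair: theta + sigma*eps and theta - sigma*eps
    # see the same task, reducing variance from differences between tasks.
    task_keys = jax.random.split(task_key, num_pairs)
    f_pos = jax.vmap(fitness_fn)(theta + sigma * eps, task_keys)
    f_neg = jax.vmap(fitness_fn)(theta - sigma * eps, task_keys)
    # Antithetic estimate of the gradient of E[f(theta + sigma * eps)].
    weights = (f_pos - f_neg) / (2.0 * sigma)
    return jnp.tensordot(weights, eps, axes=1) / num_pairs


@jax.jit
def es_step(key, theta):
    # Gradient *ascent* on fitness with a plain SGD step (an assumption).
    return theta + LEARNING_RATE * antithetic_es_grad(key, theta)


key = jax.random.PRNGKey(0)
theta = jnp.zeros(4)
for _ in range(3):
    key, subkey = jax.random.split(key)
    theta = es_step(subkey, theta)
```

Sigma sets how far each perturbation reaches. A value that is too small (such as the previously reported 3e-3) makes the fitness differences in each pair tiny relative to evaluation noise.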


Setup

Requirements

All requirements are found in setup/, with requirements-base.txt containing the majority of packages, requirements-cpu.txt containing CPU packages, and requirements-gpu.txt containing GPU packages.

Some key packages include:

Local installation (CPU)

pip install $(cat setup/requirements-base.txt setup/requirements-cpu.txt)

Docker installation (GPU)

  1. Build the Docker image:

    cd setup/docker && ./build_gpu.sh && cd ../..
  2. (To enable WandB logging) Add your account key to setup/wandb_key:

    echo [KEY] > setup/wandb_key

Running experiments

Meta-training is run with python3.8 train.py. All arguments are defined in experiments/parse_args.py; the main ones are:

    • --env_mode [env_mode]: sets the environment mode (see Grid-World environments below).
    • --num_agents [agents]: sets the meta-training batch size.
    • --num_mini_batches [mini_batches]: computes each update in sequential mini-batches, so that large batches can be run with little memory. RECOMMENDED: lower this to the smallest value that still fits in memory.
    • --debug: disables JIT compilation.
    • --log --wandb_entity [entity] --wandb_project [project]: enables logging to WandB.
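The idea behind --num_mini_batches can be sketched as gradient accumulation. The batch is split along its leading axis, per-mini-batch gradients are computed sequentially inside jax.lax.scan, and the results are averaged. Because the mini-batches run one after another, peak memory roughly scales with the mini-batch size rather than the full batch. This is a minimal sketch assuming equal-sized mini-batches and a mean loss; loss_fn and minibatched_grad are hypothetical, not the actual code in train.py.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Toy loss (hypothetical): mean squared error of a linear model.
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)


def minibatched_grad(params, batch, num_mini_batches):
    x, y = batch
    # Split the leading (batch) axis into [num_mini_batches, mini_batch_size, ...].
    xs = x.reshape((num_mini_batches, -1) + x.shape[1:])
    ys = y.reshape((num_mini_batches, -1) + y.shape[1:])

    def accumulate(grad_sum, mini_batch):
        g = jax.grad(loss_fn)(params, mini_batch)
        return jax.tree_util.tree_map(jnp.add, grad_sum, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad_sum, _ = jax.lax.scan(accumulate, zeros, (xs, ys))
    # With equal-sized mini-batches, the average of the mini-batch gradients
    # equals the full-batch gradient of the mean loss.
    return jax.tree_util.tree_map(lambda g: g / num_mini_batches, grad_sum)


minibatched_grad = jax.jit(minibatched_grad, static_argnums=2)
```

Fewer mini-batches means fewer sequential steps per update, which is why the smallest value that fits in memory is usually the fastest choice.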

Grid-World environments

Environment mode   Description                    Lifetime (# of updates)
tabular            Five tabular levels from LPG   Variable
mazes              Maze levels from MiniMax       2500
all_shortlife      Uniformly sampled levels       250
all_vrandlife      Uniformly sampled levels       10-250 (log-sampled)
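Assuming "log-sampled" means log-uniform over the range 10-250, per-agent lifetimes for all_vrandlife could be drawn as shown below. This is a hypothetical sketch, not the repository's sampler. Sampling in log space places more agents at short lifetimes than uniform sampling would: the median of a log-uniform distribution on [10, 250] is sqrt(10 * 250) = 50, rather than 130.

```python
import math

import jax
import jax.numpy as jnp


def sample_lifetimes(key, num_agents, min_life=10, max_life=250):
    # Assumed log-uniform sampling: draw uniformly in log space, then exponentiate.
    log_life = jax.random.uniform(
        key, (num_agents,), minval=math.log(min_life), maxval=math.log(max_life)
    )
    lifetimes = jnp.round(jnp.exp(log_life)).astype(jnp.int32)
    # Guard against float rounding at the interval endpoints.
    return jnp.clip(lifetimes, min_life, max_life)


lifetimes = sample_lifetimes(jax.random.PRNGKey(0), num_agents=512)
```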

Examples

    • LPG (meta-gradient). Example run (WandB): Link
        python3.8 train.py --num_agents 512 --num_mini_batches 16 --train_steps 5000 --log --wandb_entity [entity] --wandb_project [project]
    • GROOVE: the LPG command with --score_function alg_regret added. Example run: TBC
      Note: because of end-to-end compilation, algorithmic regret is currently computed at every step, which makes this very inefficient.
    • TA-LPG: the LPG command with --num_mini_batches 8 --train_steps 2500 --use_es --lifetime_conditioning --lpg_learning_rate 0.01 --env_mode all_vrandlife. Example run: TBC
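With --score_function alg_regret, GROOVE scores levels by algorithmic regret for level replay. This sketch makes two assumptions. First, algorithmic regret is taken to be the return gap between a reference agent and the LPG-trained agent on each level. Second, levels are replayed with the rank-based prioritization and staleness mixing described in the PLR paper. The names alg_regret, plr_replay_probs and sample_level are hypothetical, as are the default coefficients, and none of them come from the repository.

```python
import jax
import jax.numpy as jnp


def alg_regret(reference_returns, lpg_returns):
    # Hypothetical scoring: how far the LPG-trained agent falls short of a
    # reference agent on each level (larger gap = higher replay priority).
    return reference_returns - lpg_returns


def plr_replay_probs(scores, last_sampled, step, temperature=0.3, staleness_coef=0.3):
    n = scores.shape[0]
    # Rank-based prioritization: h(S_i) = 1 / rank(S_i), rank 1 = highest score.
    ranks = jnp.argsort(jnp.argsort(-scores)) + 1
    weights = (1.0 / ranks) ** (1.0 / temperature)
    p_score = weights / weights.sum()
    # Staleness term: favour levels that have not been sampled recently.
    staleness = (step - last_sampled).astype(jnp.float32)
    total = staleness.sum()
    p_stale = jnp.where(total > 0, staleness / jnp.maximum(total, 1e-8), 1.0 / n)
    # Mix the score-based and staleness-based distributions.
    return (1.0 - staleness_coef) * p_score + staleness_coef * p_stale


def sample_level(key, probs):
    # Draw one level index according to the replay distribution.
    return jax.random.choice(key, probs.shape[0], p=probs)


scores = alg_regret(jnp.ones(4), jnp.array([0.9, 0.2, 0.5, 1.0]))
probs = plr_replay_probs(scores, jnp.zeros(4, dtype=jnp.int32), step=10)
level = sample_level(jax.random.PRNGKey(0), probs)
```

Using ranks rather than raw regret values keeps the replay distribution insensitive to the scale of returns, which can differ widely across levels.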

Docker

To run experiments inside a CPU or GPU Docker container, use the corresponding run script, passing the GPU index as the first argument to the GPU script:

./run_gpu.sh [GPU id] python3.8 train.py [args]

Citation

If you use this implementation in your work, please cite us with the following:

@inproceedings{jackson2023discovering,
    author = {Jackson, Matthew Thomas and Jiang, Minqi and Parker-Holder, Jack and Vuorio, Risto and Lu, Chris and Farquhar, Gregory and Whiteson, Shimon and Foerster, Jakob Nicolaus},
    booktitle = {Advances in Neural Information Processing Systems},
    title = {Discovering General Reinforcement Learning Algorithms with Adversarial Environment Design},
    volume = {36},
    year = {2023}
}
@inproceedings{jackson2024discovering,
    author = {Jackson, Matthew Thomas and Lu, Chris and Kirsch, Louis and Lange, Robert Tjarko and Whiteson, Shimon and Foerster, Jakob Nicolaus},
    booktitle = {International Conference on Learning Representations},
    title = {Discovering Temporally-Aware Reinforcement Learning Algorithms},
    volume = {12},
    year = {2024}
}
