[🌏 Webpage] [📕 Paper] [💬 ICLR 2024 OpenReview (top 5% spotlight)]
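Official code release for the ICLR 2024 paper "CrossQ: Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity":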
Bhatt A.*, Palenicek D.*, Belousov B., Argus M., Amiranashvili A., Brox T., Peters J.
Execute the following commands to set up a conda environment for running experiments:
conda create -n crossq python=3.11.5
conda activate crossq
conda install -c nvidia cuda-nvcc=12.3.52
pip install -e .
pip install "jax[cuda12_pip]==0.4.19" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
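After installation, it can help to confirm that JAX actually sees the GPU before launching a long run. The snippet below is a minimal sketch and not part of the repository; the helper name `jax_backend_summary` is hypothetical, and it relies only on `jax.devices()` and `jax.default_backend()`.

```python
import importlib.util


def jax_backend_summary():
    """Return a one-line description of the JAX install (hypothetical helper)."""
    # Avoid an ImportError if the snippet is run outside the crossq environment.
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this environment"
    import jax

    # default_backend() reports the platform JAX will use, e.g. "cpu" or "gpu".
    backend = jax.default_backend()
    devices = jax.devices()
    return f"jax {jax.__version__}, backend={backend}, devices={devices}"


print(jax_backend_summary())
```

If the reported backend is `cpu`, the CUDA-enabled wheel was likely not picked up, and the install commands above are worth re-running inside the activated environment.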
The main entry point for running experiments is train.py. You can configure experiments with the appropriate environment and agent flags. For more info, run python train.py --help.
To train with WandB logging, run the following command. It trains a CrossQ agent on the Humanoid-v4 environment with seed 9 and logs the results to your WandB entity and project:
python train.py -algo crossq -env Humanoid-v4 -seed 9 -wandb_mode 'online' -wandb_entity my_team -wandb_project crossq
To train without WandB logging, run the following command, and in a different terminal run tensorboard --logdir logs to visualize training progress:
python train.py -algo crossq -env Humanoid-v4 -seed 9 -wandb_mode 'disabled'
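Experiments are often repeated over several seeds. The sketch below shows one way to launch the command above for a list of seeds from Python. The flag names are taken from the commands in this README; the helper names `build_command` and `run_seeds`, the seed list, and the `dry_run` switch are illustrative assumptions and not part of the repository.

```python
import shlex
import subprocess


def build_command(seed, algo="crossq", env="Humanoid-v4", wandb_mode="disabled"):
    """Build the train.py invocation for one seed (hypothetical helper)."""
    return [
        "python", "train.py",
        "-algo", algo,
        "-env", env,
        "-seed", str(seed),
        "-wandb_mode", wandb_mode,
    ]


def run_seeds(seeds, dry_run=True, **kwargs):
    """Print, and optionally run sequentially, one training command per seed."""
    commands = [build_command(seed, **kwargs) for seed in seeds]
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            # check=True stops the sweep as soon as one run fails.
            subprocess.run(cmd, check=True)
    return commands


# Dry run by default: only prints the commands that would be executed.
commands = run_seeds([0, 1, 2])
```

Passing `dry_run=False` from the repository root would run the seeds one after another; running them in parallel is better left to a scheduler.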
To train on a cluster, we provide example Slurm scripts in /slurm for the various experiments, baselines, and ablations performed in the paper.
These configurations are cluster-specific and will probably need to be adjusted for your cluster. However, they should serve as a starting point.
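For orientation, a Slurm job array is one common way to map seeds to cluster jobs. The sketch below renders a minimal batch script as text. It is not one of the scripts shipped in /slurm; the function name `render_sbatch`, the resource requests, and the time limit are assumptions that would need to match your cluster's policies.

```python
def render_sbatch(env="Humanoid-v4", algo="crossq", num_seeds=5, time_limit="24:00:00"):
    """Render a minimal Slurm array script as a string (hypothetical helper)."""
    lines = [
        "#!/bin/bash",
        f"#SBATCH --job-name={algo}-{env}",
        # One array task per seed; task IDs run from 0 to num_seeds - 1.
        f"#SBATCH --array=0-{num_seeds - 1}",
        # Assumed resource request: a single GPU per task.
        "#SBATCH --gres=gpu:1",
        f"#SBATCH --time={time_limit}",
        "",
        "# Assumes the crossq conda environment is active in the submitting shell.",
        f"python train.py -algo {algo} -env {env} "
        "-seed $SLURM_ARRAY_TASK_ID -wandb_mode disabled",
    ]
    return "\n".join(lines) + "\n"


script = render_sbatch()
print(script)
```

Writing the returned string to a file and submitting it with `sbatch` would start one training run per array task, each using its task ID as the seed.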
To cite our paper and/or this repository in publications:
@inproceedings{
bhatt2024crossq,
title={CrossQ: Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity},
author={Aditya Bhatt and Daniel Palenicek and Boris Belousov and Max Argus and Artemij Amiranashvili and Thomas Brox and Jan Peters},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=PczQtTsTIX}
}
The implementation is built upon code from Stable Baselines JAX.