Official implementation of the RLC 2024 paper "Policy-Guided Diffusion"
MIT License

Policy-Guided Diffusion


The official implementation of Policy-Guided Diffusion - built by Matthew Jackson and Michael Matthews.

Diffusion and agent training are implemented entirely in JAX, with extensive JIT compilation and parallelization!
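The JIT-compiled, parallelized training mentioned above typically follows JAX's `jit`-over-`vmap` pattern. A minimal, self-contained sketch of that pattern (illustrative only, not the repo's actual code; the toy loss and parameter names are invented here):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    # Toy per-example loss: squared error of a linear model against 1.0
    # (hypothetical; the repo trains diffusion and agent models instead)
    return jnp.mean((params["w"] * x - 1.0) ** 2)

# vmap parallelizes the loss across the batch dimension
batched_loss = jax.vmap(loss_fn, in_axes=(None, 0))

# jit compiles the whole update step with XLA
@jax.jit
def train_step(params, batch, lr=0.1):
    grads = jax.grad(lambda p: jnp.mean(batched_loss(p, batch)))(params)
    return {"w": params["w"] - lr * grads["w"]}

params = {"w": jnp.array(0.0)}
batch = jnp.ones((8, 4))
params = train_step(params, batch)
print(params["w"])  # ~0.2 after one gradient step
```

Because the step is JIT-compiled, subsequent calls with the same shapes reuse the compiled XLA program, which is where most of the speedup comes from.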

Update (28/06/24): Added WandB report with diffusion and agent model training logs.

Running experiments

Diffusion and agent training are run with `python3 train_diffusion.py` and `python3 train_agent.py` respectively, with all arguments defined in `util/args.py`.
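As an illustration of how such scripts are typically driven, here is a hypothetical sketch of an argparse-style interface. The real flag list lives in `util/args.py` and may differ; the names below only echo the example commands later in this README:

```python
import argparse

# Hypothetical sketch of flags like those in util/args.py (names assumed)
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", default="walker2d-medium-v2")
parser.add_argument("--log", action="store_true")
parser.add_argument("--wandb_project", default=None)

# Parse an explicit argv list so the sketch runs anywhere
args = parser.parse_args(["--log", "--wandb_project", "diff"])
print(args.dataset_name, args.log, args.wandb_project)
```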

Docker installation

  1. Build the Docker image:

    cd docker && ./build.sh && cd ..
  2. (To enable WandB logging) Add your account key to docker/wandb_key:

    echo [KEY] > docker/wandb_key

Launching experiments

./run_docker.sh [GPU index] python3.9 [train_script] [args]

Diffusion training example:

./run_docker.sh 0 python3.9 train_diffusion.py --log --wandb_project diff --wandb_team flair --dataset_name walker2d-medium-v2

Agent training example:

./run_docker.sh 6 python3.9 train_agent.py --log --wandb_project agents --wandb_team flair --dataset_name walker2d-medium-v2 --agent iql
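For intuition about what these scripts train toward, the paper's core idea is to guide the diffusion model's denoising with the target policy, in the spirit of classifier guidance. The sketch below is a rough, hypothetical illustration (the function names, signatures, and toy stand-ins are all invented; see the paper and repo for the actual formulation):

```python
import jax
import jax.numpy as jnp

def policy_guided_score(x_t, t, score_fn, policy_logprob, coef):
    """Add a policy-guidance term to the learned score (hypothetical sketch).

    score_fn:       the trained diffusion model's score (assumed signature)
    policy_logprob: log-likelihood of the noisy trajectory under the
                    target policy (assumed; simplified to act on x_t alone)
    coef:           guidance coefficient (schedule simplified away)
    """
    base = score_fn(x_t, t)
    guidance = jax.grad(policy_logprob)(x_t)
    return base + coef * guidance

# Toy stand-ins purely for illustration
score_fn = lambda x, t: -x                        # toy learned score
policy_logprob = lambda x: -0.5 * jnp.sum(x ** 2) # toy Gaussian policy term

x = jnp.ones(3)
out = policy_guided_score(x, 0, score_fn, policy_logprob, coef=0.5)
print(out)  # [-1.5 -1.5 -1.5]
```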

Citation

If you use this implementation in your work, please cite us with the following:

@misc{jackson2024policyguided,
      title={Policy-Guided Diffusion},
      author={Matthew Thomas Jackson and Michael Tryfan Matthews and Cong Lu and Benjamin Ellis and Shimon Whiteson and Jakob Foerster},
      year={2024},
      eprint={2404.06356},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}