Shang-Fu Chen*, Hsiang-Chun Wang*, Ming-Hao Hsu, Chun-Mao Lai, Shao-Hua Sun at NTU RLL lab
This is the official PyTorch implementation of the paper "Diffusion Model-Augmented Behavioral Cloning" (ICML 2024).
This code requires Python 3.7.2 or higher. All package requirements are listed in requirements.txt. To install from scratch using Anaconda, use the following commands:
conda create -n [your_env_name] python=3.7.2
conda activate [your_env_name]
pip install -r requirements.txt
cd d4rl
pip install -e .
cd ../rl-toolkit
pip install -e .
cd ..
mkdir -p data/trained_models
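Before running the commands above, it can help to confirm that the active interpreter meets the Python 3.7.2 requirement. The helper below is a hypothetical convenience script, not part of this repository; `meets_python_requirement` and `MIN_PYTHON` are illustrative names.

```python
import sys

# Minimum interpreter version stated in this README.
MIN_PYTHON = (3, 7, 2)


def meets_python_requirement(version=None, minimum=MIN_PYTHON):
    """Return True if `version` (defaults to the running interpreter) is >= `minimum`."""
    if version is None:
        version = sys.version_info[:3]
    # Tuples compare element by element, so (3, 10, 0) >= (3, 7, 2) is True.
    return tuple(version) >= minimum


if __name__ == "__main__":
    current = ".".join(map(str, sys.version_info[:3]))
    if meets_python_requirement():
        print(f"Python {current} satisfies the >= 3.7.2 requirement.")
    else:
        sys.exit(f"Python {current} is older than 3.7.2; please upgrade.")
```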
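To set up Weights & Biases (W&B), log in with
wandb login <YOUR_API_KEY>
and then edit config.yaml with your W&B username and project name.
To pretrain a diffusion model, run dbc/ddpm.py. For policy learning, either run dbc/main.py for a single experiment or run wandb sweep configs/<env>/<alg.yaml> to launch a wandb sweep. All configuration files are in configs.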
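The diffusion model pretrained by dbc/ddpm.py is intended to model the joint distribution of expert state-action pairs. The sketch below shows the standard DDPM noise-prediction objective applied to concatenated (state, action) vectors, written in NumPy for brevity. The linear beta schedule, the 1000 diffusion steps, the generic `eps_model(x_t, t)` denoiser, and all function names are illustrative assumptions, not values or APIs taken from dbc/ddpm.py.

```python
import numpy as np


def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Noise variances beta_1..beta_T, linearly spaced (assumed schedule)."""
    return np.linspace(beta_start, beta_end, num_steps)


def q_sample(x0, t, noise, alphas_cumprod):
    """Forward (noising) process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar = alphas_cumprod[t][:, None]  # one abar_t per sample, broadcast over features
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * noise


def ddpm_loss(eps_model, states, actions, alphas_cumprod, rng):
    """Noise-prediction loss on concatenated (state, action) vectors."""
    x0 = np.concatenate([states, actions], axis=1)
    t = rng.integers(0, len(alphas_cumprod), size=len(x0))  # random timestep per sample
    noise = rng.standard_normal(x0.shape)
    x_t = q_sample(x0, t, noise, alphas_cumprod)
    eps_hat = eps_model(x_t, t)  # the network's guess of the injected noise
    return float(np.mean((eps_hat - noise) ** 2))


if __name__ == "__main__":
    alphas_cumprod = np.cumprod(1.0 - linear_beta_schedule(1000))
    rng = np.random.default_rng(0)
    states = rng.standard_normal((32, 4))   # toy batch: 32 states of dimension 4
    actions = rng.standard_normal((32, 2))  # matching 2-D expert actions
    # A trivial "denoiser" that always predicts zero noise has expected loss 1.0.
    zero_model = lambda x_t, t: np.zeros_like(x_t)
    print(ddpm_loss(zero_model, states, actions, alphas_cumprod, rng))
```

A real denoiser (for example an MLP whose width corresponds to --hidden-dim) would be trained to push this loss well below the zero-prediction baseline.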
For each task, the commands below show how to train the diffusion model and where the corresponding configuration files are located.
Maze2D:
python dbc/ddpm.py --traj-load-path expert_datasets/maze.pt --num-epoch 8000 --lr 0.0001 --hidden-dim 128
python dbc/main.py --alg dbc --bc-num-epochs 2000 --depth 3 --hidden-dim 256 --coeff 30 --coeff-bc 1 --ddpm-path data/dm/trained_models/maze_ddpm.pt --env-name maze2d-medium-v2 --lr 0.00005 --traj-load-path ./expert_datasets/maze.pt --seed 1
./wandb.sh ./configs/maze/dbc.yaml
./wandb.sh ./configs/maze/bc.yaml
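The single-run Maze2D command weights two loss terms with --coeff-bc and --coeff. The sketch below illustrates one way such a combined objective can be computed, assuming --coeff-bc scales a mean-squared-error BC loss and --coeff scales a diffusion-model loss that penalizes policy actions only when the pretrained diffusion model denoises them worse than the expert's actions (a hinge on the difference of denoising errors). This follows the objective described in the paper, but the function name, its arguments, and the per-sample hinge are assumptions; rl-toolkit/rlf/algos/il/dbc.py is the authoritative implementation. The NumPy version computes the loss value only; a PyTorch version would backpropagate through `policy_actions`.

```python
import numpy as np


def dbc_loss(policy_actions, expert_actions, states, eps_model, alphas_cumprod,
             rng, coeff_bc=1.0, coeff=30.0):
    """Sketch of coeff_bc * L_BC + coeff * L_DM (loss value only, no gradients)."""
    # BC term: mean squared error between policy and expert actions.
    bc = np.mean((policy_actions - expert_actions) ** 2)

    # Share one timestep and one noise draw per sample between the two
    # diffusion losses so they differ only in which action is denoised.
    n = len(states)
    t = rng.integers(0, len(alphas_cumprod), size=n)
    abar = alphas_cumprod[t][:, None]
    noise = rng.standard_normal((n, states.shape[1] + expert_actions.shape[1]))

    def denoise_error(actions):
        # Per-sample noise-prediction error of the (frozen) diffusion model.
        x0 = np.concatenate([states, actions], axis=1)
        x_t = np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * noise
        return np.mean((eps_model(x_t, t) - noise) ** 2, axis=1)

    agent_err = denoise_error(policy_actions)
    expert_err = denoise_error(expert_actions)
    # Hinge: only penalize policy actions the model finds less expert-like
    # than the expert's own action for the same state.
    dm = np.mean(np.maximum(agent_err - expert_err, 0.0))
    return float(coeff_bc * bc + coeff * dm)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    alphas_cumprod = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
    states = rng.standard_normal((16, 3))
    expert = rng.standard_normal((16, 2))
    toy_model = lambda x_t, t: 0.5 * x_t  # stand-in for a pretrained denoiser
    # Imitating the expert exactly zeroes both terms.
    print(dbc_loss(expert, expert, states, toy_model, alphas_cumprod, rng))  # 0.0
```

The defaults `coeff_bc=1.0` and `coeff=30.0` mirror the --coeff-bc 1 and --coeff 30 values in the Maze2D command.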
Fetch Pick:
python dbc/ddpm.py --traj-load-path expert_datasets/pick.pt --num-epoch 10000 --lr 0.001 --hidden-dim 1024
./wandb.sh ./configs/fetchPick/dbc.yaml
./wandb.sh ./configs/fetchPick/bc.yaml
Hand Rotate:
python dbc/ddpm.py --traj-load-path expert_datasets/hand.pt --num-epoch 10000 --lr 0.00003 --hidden-dim 2048
./wandb.sh ./configs/hand/dbc.yaml
./wandb.sh ./configs/hand/bc.yaml
HalfCheetah:
python dbc/ddpm.py --traj-load-path expert_datasets/halfcheetah.pt --num-epoch 8000 --lr 0.0002 --hidden-dim 1024
./wandb.sh ./configs/halfcheetah/dbc.yaml
./wandb.sh ./configs/halfcheetah/bc.yaml
Walker:
python dbc/ddpm.py --traj-load-path expert_datasets/walker.pt --num-epoch 8000 --lr 0.0002 --hidden-dim 1024
./wandb.sh ./configs/walker/dbc.yaml
./wandb.sh ./configs/walker/bc.yaml
AntReach:
python dbc/ddpm.py --traj-load-path expert_datasets/ant.pt --num-epoch 20000 --lr 0.0002 --hidden-dim 1024 --norm False
./wandb.sh ./configs/antReach/dbc.yaml
./wandb.sh ./configs/antReach/bc.yaml
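To reproduce every experiment end to end, the per-task commands above can be generated from a single table. The helper below is a hypothetical script (not part of the repository); `DDPM_ARGS` and `reproduction_commands` are illustrative names, and its hyperparameters are copied verbatim from the commands listed above. It only prints the commands, so each one can be reviewed before it is run.

```python
# Hypothetical helper: regenerates the reproduction commands listed above.
# Learning rates are kept as strings so the printed flags match the README exactly.
DDPM_ARGS = {
    "maze": {"traj": "maze.pt", "num_epoch": 8000, "lr": "0.0001", "hidden_dim": 128},
    "fetchPick": {"traj": "pick.pt", "num_epoch": 10000, "lr": "0.001", "hidden_dim": 1024},
    "hand": {"traj": "hand.pt", "num_epoch": 10000, "lr": "0.00003", "hidden_dim": 2048},
    "halfcheetah": {"traj": "halfcheetah.pt", "num_epoch": 8000, "lr": "0.0002", "hidden_dim": 1024},
    "walker": {"traj": "walker.pt", "num_epoch": 8000, "lr": "0.0002", "hidden_dim": 1024},
    "antReach": {"traj": "ant.pt", "num_epoch": 20000, "lr": "0.0002", "hidden_dim": 1024,
                 "extra": "--norm False"},
}


def reproduction_commands(task):
    """Return the diffusion-pretraining command followed by the DBC and BC sweep commands."""
    cfg = DDPM_ARGS[task]
    ddpm = (
        f"python dbc/ddpm.py --traj-load-path expert_datasets/{cfg['traj']} "
        f"--num-epoch {cfg['num_epoch']} --lr {cfg['lr']} --hidden-dim {cfg['hidden_dim']}"
    )
    if "extra" in cfg:
        ddpm += " " + cfg["extra"]  # task-specific flags, e.g. --norm False for AntReach
    sweeps = [f"./wandb.sh ./configs/{task}/{alg}.yaml" for alg in ("dbc", "bc")]
    return [ddpm] + sweeps


if __name__ == "__main__":
    # Print (rather than execute) every command so it can be reviewed first.
    for task in DDPM_ARGS:
        print(f"# {task}")
        for cmd in reproduction_commands(task):
            print(cmd)
```

Storing learning rates as strings is deliberate: formatting the float 0.00003 would print 3e-05 rather than the flag value shown above. The single-run dbc/main.py command for Maze2D is not included.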
The main files are:
rl-toolkit/rlf/algos/il/dbc.py: algorithm of our method (DBC)
rl-toolkit/rlf/algos/il/bc.py: algorithm of the BC baseline
d4rl/d4rl/pointmaze/maze_model.py: Maze2D task
dbc/envs/fetch/custom_fetch.py: Fetch Pick task
dbc/envs/hand/manipulate.py: Hand Rotate task
Citation:
@inproceedings{chen2024diffusion,
title={Diffusion Model-Augmented Behavioral Cloning},
author={Shang-Fu Chen and Hsiang-Chun Wang and Ming-Hao Hsu and Chun-Mao Lai and Shao-Hua Sun},
booktitle={Forty-first International Conference on Machine Learning},
year={2024},
url={https://openreview.net/forum?id=OnidGtOhg3}
}