This repository provides code for training on BridgeData V2.
We provide implementations for the following subset of methods described in the paper:
The official implementations and papers for all the methods can be found here:
Please open a GitHub issue if you encounter problems with this code.
The raw dataset (consisting of JPEGs, PNGs, and pkl files) can be downloaded here. The demos*.zip file contains the demonstration data, and scripted*.zip contains the data collected with a scripted policy. For training, the raw data needs to be converted into a format that is compatible with a data loader. We offer two options:
The first option is a tf.data loader. This data loader is implemented in jaxrl_m/data/bridge_dataset.py and is used by the training script in this repo. The scripts in the data_processing folder convert the raw data into the format required by this data loader. First, use bridgedata_raw_to_numpy.py to convert the raw data into NumPy files. Then, use bridgedata_numpy_to_tfrecord.py to convert the NumPy files into TFRecord files.
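For the goal-conditioned methods, each training sample pairs an observation with a goal image drawn from the same trajectory. The sketch below illustrates one common way to do this, hindsight relabeling with goals sampled uniformly from the current or a later timestep. It is an assumption about the general technique, not a copy of what jaxrl_m/data/bridge_dataset.py does; the function name relabel_uniform and the (T, ...) array layout are hypothetical.

```python
import numpy as np


def relabel_uniform(observations: np.ndarray, rng: np.random.Generator):
    """Hypothetical helper illustrating uniform hindsight goal relabeling.

    observations: array of shape (T, ...) holding one trajectory.
    Returns (goal_idx, goals), where goal_idx[t] is drawn uniformly from
    [t, T - 1] and goals[t] = observations[goal_idx[t]].
    """
    T = observations.shape[0]
    t = np.arange(T)
    # Array-valued lower bound: each timestep gets its own range [t, T).
    goal_idx = rng.integers(t, T)
    return goal_idx, observations[goal_idx]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Toy trajectory: 5 timesteps, each "observation" is just its own index.
    traj = np.arange(5)[:, None]
    idx, goals = relabel_uniform(traj, rng)
    print(idx, goals[:, 0])
```

Sampling only current or future timesteps keeps every goal reachable from its observation; a real loader may also upweight the final state or mix in other strategies.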
The second option is a TFDS dataset, which can be loaded as a tf.data pipeline. We offer a pre-processed TFDS version of the dataset (downsampled to 256x256) in the tfds folder here. In the TFDS dataset, the trajectories are structured using the RLDS format. We recommend using the Octo data loader for loading the RLDS version of BridgeData. If you would like to reprocess BridgeData into RLDS (e.g., to change the resolution or add keys), you can use this repo.

To start training, run the command below. Replace METHOD with one of gc_bc, gc_ddpm_bc, gc_iql, or contrastive_rl_td, and replace NAME with a name for the run.
python experiments/train.py \
--config experiments/configs/train_config.py:METHOD \
--bridgedata_config experiments/configs/data_config.py:all \
--name NAME
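To train all four methods in one go, a small Python wrapper can build the same command for each of them. This is a hypothetical helper, not part of the repo; the run-naming scheme (prefix plus method) is an arbitrary choice, and it only prints the commands unless dry_run is set to False.

```python
import shlex
import subprocess
from typing import List

# The METHOD values listed in this README.
METHODS = ("gc_bc", "gc_ddpm_bc", "gc_iql", "contrastive_rl_td")


def train_command(method: str, name: str) -> List[str]:
    """Build the argv for one training run, mirroring the command above."""
    if method not in METHODS:
        raise ValueError(f"unknown method {method!r}; expected one of {METHODS}")
    return [
        "python", "experiments/train.py",
        "--config", f"experiments/configs/train_config.py:{method}",
        "--bridgedata_config", "experiments/configs/data_config.py:all",
        "--name", name,
    ]


def launch_all(prefix: str, dry_run: bool = True) -> None:
    """Run every method sequentially; with dry_run=True, only print the commands."""
    for method in METHODS:
        cmd = train_command(method, f"{prefix}_{method}")
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    launch_all("my_sweep")
```

Passing the command as a list (rather than a shell string) avoids quoting problems if a run name contains spaces.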
Training hyperparameters can be modified in experiments/configs/train_config.py, and data parameters (e.g., subsets to include or exclude) can be modified in experiments/configs/data_config.py.
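The :METHOD and :all suffixes in the training command select a variant from each config file. In the ml_collections config-file convention, the text after the colon is passed to the file's get_config function; we assume these files follow that convention, so check them before relying on it. The stdlib-only sketch below mimics that dispatch with plain dicts, and every key and value in it is a hypothetical placeholder rather than a real setting from this repo.

```python
from copy import deepcopy
from typing import Dict, Tuple

# Hypothetical shared settings; the real train_config.py defines its own.
BASE: Dict[str, object] = {"seed": 0}

# Hypothetical per-variant overrides, keyed by the string after the colon.
VARIANTS: Dict[str, Dict[str, object]] = {
    "gc_bc": {"agent": "gc_bc"},
    "gc_ddpm_bc": {"agent": "gc_ddpm_bc"},
    "gc_iql": {"agent": "gc_iql"},
    "contrastive_rl_td": {"agent": "contrastive_rl_td"},
}


def get_config(config_string: str) -> Dict[str, object]:
    """Return a copy of BASE merged with the variant named by config_string."""
    if config_string not in VARIANTS:
        raise KeyError(f"unknown variant {config_string!r}")
    config = deepcopy(BASE)  # never mutate the shared defaults
    config.update(VARIANTS[config_string])
    return config


def resolve(flag_value: str) -> Tuple[str, Dict[str, object]]:
    """Split 'path/to/file.py:variant' as written on the command line.

    A real loader would import the module at `path` and call its get_config;
    this sketch calls the local get_config instead.
    """
    path, _, variant = flag_value.partition(":")
    return path, get_config(variant)
```

If the config files are ml_collections configs, individual fields can typically also be overridden on the command line with --config.<field>=<value> instead of editing the file.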
First, set up the robot hardware according to our guide. Install our WidowX robot controller stack from this repo.
There are two ways to interface a policy with the robot controller: the docker compose service method or the server-client method. Refer to the bridge_data_robot docs for an explanation of how to set up each method. In general, we recommend the server-client method.
For the server-client method, start the server on the robot. Then run the following commands on the client. You can specify the IP of the remote server via the --ip flag. The default IP is localhost (i.e., the server and client are the same machine).
# Specify the path to the downloaded checkpoints directory
export CHECKPOINT_DIR=/path/to/checkpoint_dir
# For GCBC
python experiments/eval.py \
--checkpoint_weights_path $CHECKPOINT_DIR/checkpoint_300000 \
--checkpoint_config_path $CHECKPOINT_DIR/gcbc_256_config.json \
--im_size 256 --goal_type gc --show_image --blocking
# For LCBC
python experiments/eval.py \
--checkpoint_weights_path $CHECKPOINT_DIR/checkpoint_145000 \
--checkpoint_config_path $CHECKPOINT_DIR/lcbc_256_config.json \
--im_size 256 --goal_type lc --show_image --blocking
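When switching between checkpoints or pointing the client at a remote server, it can help to assemble the eval command programmatically. The helper below is a hypothetical sketch that only uses the flags shown in this README (including --ip); the function name and its parameters are our own.

```python
import os
from typing import List, Optional


def eval_command(checkpoint_dir: str, step: int, config_name: str,
                 im_size: int, goal_type: str,
                 ip: Optional[str] = None) -> List[str]:
    """Build an experiments/eval.py command like the GCBC and LCBC examples above.

    goal_type is "gc" (goal image) or "lc" (language); ip is optional and
    defaults to the script's own default (localhost) when omitted.
    """
    if goal_type not in ("gc", "lc"):
        raise ValueError("goal_type must be 'gc' or 'lc'")
    cmd = [
        "python", "experiments/eval.py",
        "--checkpoint_weights_path", os.path.join(checkpoint_dir, f"checkpoint_{step}"),
        "--checkpoint_config_path", os.path.join(checkpoint_dir, config_name),
        "--im_size", str(im_size),
        "--goal_type", goal_type,
        "--show_image", "--blocking",
    ]
    if ip is not None:
        cmd += ["--ip", ip]
    return cmd
```

For example, eval_command("/ckpts", 300000, "gcbc_256_config.json", 256, "gc", ip="192.168.1.10") reproduces the GCBC command with a remote server address; the result can be passed to subprocess.run.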
You can also specify an initial position for the end effector with the --initial_eep flag. Similarly, use the --goal_eep flag to specify the position of the end effector when taking a goal image.
To evaluate image-conditioned or language-conditioned methods with the docker compose service method, run eval_gc.py or eval_lc.py, respectively, in the bridge_data_v2 docker container.
Checkpoints for GCBC, LCBC, D-GCBC, GCIQL, and CRL are available here. Each checkpoint has an associated JSON file with its configuration information. The name of each checkpoint indicates whether it was trained with 128x128 images or 256x256 images.
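Since the --im_size flag has to match the resolution a checkpoint was trained with, a small parser can read the resolution from the config filename. The pattern below is an assumption generalized from the two names that appear in this README (gcbc_256_config.json and lcbc_256_config.json); other checkpoints may be named differently, and the JSON keys are not inspected because we do not know them.

```python
import json
import re
from pathlib import Path
from typing import Any, Dict, Tuple

# Assumed pattern: "<method>_<128|256>_config.json", e.g. "gcbc_256_config.json".
_CONFIG_NAME = re.compile(r"^(?P<method>[a-z0-9_-]+?)_(?P<size>128|256)_config\.json$")


def parse_config_name(filename: str) -> Tuple[str, int]:
    """Return (method, image size) parsed from a checkpoint config filename."""
    match = _CONFIG_NAME.match(Path(filename).name)
    if match is None:
        raise ValueError(f"unrecognized config filename: {filename!r}")
    return match.group("method"), int(match.group("size"))


def load_config(path: str) -> Dict[str, Any]:
    """Load the JSON configuration that accompanies a checkpoint."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)
```

For example, parse_config_name("/path/to/checkpoint_dir/gcbc_256_config.json") returns ("gcbc", 256), which tells you to pass --im_size 256.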
We don't currently have checkpoints for ACT or RT-1 available but may release them soon.
The dependencies for this codebase can be installed in a conda environment:
conda create -n jaxrl python=3.10
conda activate jaxrl
pip install -e .
pip install -r requirements.txt
For GPU:
pip install --upgrade "jax[cuda11_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
For TPU:
pip install --upgrade "jax[tpu]==0.4.13" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
See the JAX GitHub page for more details on installing JAX.
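After installing the GPU or TPU build, a quick sanity check is to ask JAX which backend it selected; if it reports cpu, the accelerator build was not picked up. The function below is a small helper of our own that uses jax.default_backend() and jax.devices(), and it degrades gracefully when JAX is not importable.

```python
def describe_backend() -> str:
    """Report which backend JAX will use, or that JAX is not installed."""
    try:
        import jax
    except ImportError:
        return "JAX is not installed in this environment"
    # default_backend() returns a platform name such as "cpu", "gpu", or "tpu".
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={jax.default_backend()} devices=[{devices}]"


if __name__ == "__main__":
    print(describe_backend())
```

Run it inside the activated jaxrl environment so it checks the same installation the training script will use.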
This code is based on jaxrl_m from Dibya Ghosh.
If you use this code and/or BridgeData V2 in your work, please cite the paper with:
@inproceedings{walke2023bridgedata,
title={BridgeData V2: A Dataset for Robot Learning at Scale},
author={Walke, Homer and Black, Kevin and Lee, Abraham and Kim, Moo Jin and Du, Max and Zheng, Chongyi and Zhao, Tony and Hansen-Estruch, Philippe and Vuong, Quan and He, Andre and Myers, Vivek and Fang, Kuan and Finn, Chelsea and Levine, Sergey},
booktitle={Conference on Robot Learning (CoRL)},
year={2023}
}