Code for the paper Zero-Shot Robotic Manipulation With Pretrained Image-Editing Diffusion Models.
This repository contains the code for training the high-level image-editing diffusion model on video data. For training the low-level policy, head over to the BridgeData V2 repository: we use the `gc_ddpm_bc` agent, unmodified, with an action prediction horizon of 4 and the `delta_goals` relabeling strategy.
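The `delta_goals` strategy relabels each timestep of a trajectory with a goal image sampled from a bounded window of future frames. A toy illustration of the idea (the window bounds and end-of-trajectory clipping here are illustrative assumptions, not the exact BridgeData V2 implementation):

```python
import numpy as np

def relabel_delta_goals(traj_len: int, delta_min: int = 16, delta_max: int = 24, seed: int = 0):
    # For each timestep t, choose a goal frame index t + delta, with delta
    # drawn uniformly from [delta_min, delta_max), clipped to the last frame.
    rng = np.random.default_rng(seed)
    deltas = rng.integers(delta_min, delta_max, size=traj_len)
    return np.minimum(np.arange(traj_len) + deltas, traj_len - 1)

goal_idx = relabel_delta_goals(100)
# every goal lies at or after its own timestep, and never past the trajectory end
assert (goal_idx >= np.arange(100)).all() and goal_idx.max() == 99
```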
For integration with the CALVIN simulator and reproducing our simulated results, see our fork of the `calvin-sim` repo and the corresponding documentation in the BridgeData V2 repository.
Data processing uses dlimp; see the `scripts/` directory inside dlimp for creating TFRecords in a compatible format.

Run `pip install -r requirements.txt` to install the versions of the required packages confirmed to be working with this codebase, then run `pip install -e .` to install this package. Only tested with Python 3.10. You'll also have to manually install Jax for your platform (see the Jax installation instructions). Make sure you have the Jax version specified in `requirements.txt` (rather than using `--upgrade` as suggested in the Jax docs).

Once you have your data and have configured `base.py`, you can start training by running `python scripts/train.py --config configs/base.py:base`.

The robot evaluation scripts live in the `scripts/robot` directory. You probably won't be able to run them, since you don't have our robot setup, but they are there for reference. See `create_sample_fn` in `susie/model.py` for canonical sampling code.

The UNet weights for our best-performing model, trained on BridgeData and Something-Something for 40k steps, are hosted on HuggingFace. They can be loaded using `FlaxUNet2DConditionModel.from_pretrained("kvablack/susie", subfolder="unet")`. Use them with the standard Stable Diffusion v1-5 VAE and text encoder.
Here's a quickstart for getting out-of-the-box subgoals using this repo:

```python
from susie.model import create_sample_fn
from susie.jax_utils import initialize_compilation_cache
import requests
import numpy as np
from PIL import Image

initialize_compilation_cache()

IMAGE_URL = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg"

sample_fn = create_sample_fn("kvablack/susie")
image = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256, 256)))
image_out = sample_fn(image, "open the drawer")

# to display the images if you're in a Jupyter notebook
display(Image.fromarray(image))
display(Image.fromarray(image_out))
```
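Outside of a notebook, `display` won't be defined; saving the subgoal to disk works anywhere (a trivial helper for illustration, not part of this repo):

```python
import numpy as np
from PIL import Image

def save_subgoal(image_out, path="subgoal.png"):
    # image_out is the HxWx3 uint8 array returned by sample_fn
    Image.fromarray(image_out).save(path)

# standalone example with a dummy array standing in for a sampled subgoal
save_subgoal(np.zeros((256, 256, 3), dtype=np.uint8), "subgoal.png")
```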