
Code for subgoal synthesis via image editing
https://rail-berkeley.github.io/susie
MIT License

susie

Code for the paper Zero-Shot Robotic Manipulation With Pretrained Image-Editing Diffusion Models.

This repository contains the code for training the high-level image-editing diffusion model on video data. For training the low-level policy, head over to the BridgeData V2 repository. We use the gc_ddpm_bc agent, unmodified, with an action prediction horizon of 4 and the delta_goals relabeling strategy.
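The low-level policy setup above can be summarized as a small configuration sketch. Note the key names below are illustrative assumptions for readability, not the exact field names used in the BridgeData V2 configs:

```python
# Hypothetical summary of the low-level policy settings described above.
# Key names are assumptions; consult the BridgeData V2 repo for the real configs.
policy_config = {
    "agent": "gc_ddpm_bc",          # goal-conditioned DDPM behavior cloning, unmodified
    "action_pred_horizon": 4,       # predict 4 actions per step
    "relabeling_strategy": "delta_goals",  # goal relabeling used during training
}
```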

For integration with the CALVIN simulator and reproducing our simulated results, see our fork of the calvin-sim repo and the corresponding documentation in the BridgeData V2 repository.

Model Weights

The UNet weights for our best-performing model, trained on BridgeData and Something-Something for 40k steps, are hosted on HuggingFace. They can be loaded using FlaxUNet2DConditionModel.from_pretrained("kvablack/susie", subfolder="unet"). Use with the standard Stable Diffusion v1-5 VAE and text encoder.
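If you want to assemble the components manually rather than through this repo's helpers, the loading described above might look like the following sketch using diffusers' Flax API. The helper name and the choice of the runwayml/stable-diffusion-v1-5 repo for the VAE, text encoder, and tokenizer are assumptions; nothing is downloaded until the function is actually called:

```python
# Hedged sketch: combine the released SuSIE UNet with the standard
# Stable Diffusion v1-5 VAE and text encoder via diffusers' Flax classes.
# load_susie_components is a hypothetical helper, not part of this repo.
def load_susie_components(sd_repo="runwayml/stable-diffusion-v1-5"):
    from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel
    from transformers import CLIPTokenizer, FlaxCLIPTextModel

    # Flax from_pretrained returns (module, params) pairs.
    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        "kvablack/susie", subfolder="unet"
    )
    vae, vae_params = FlaxAutoencoderKL.from_pretrained(sd_repo, subfolder="vae")
    text_encoder = FlaxCLIPTextModel.from_pretrained(sd_repo, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(sd_repo, subfolder="tokenizer")
    return unet, unet_params, vae, vae_params, text_encoder, tokenizer
```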

Here's a quickstart for getting out-of-the-box subgoals using this repo:

from susie.model import create_sample_fn
from susie.jax_utils import initialize_compilation_cache
import requests
import numpy as np
from PIL import Image

initialize_compilation_cache()

IMAGE_URL = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_12.jpg"

sample_fn = create_sample_fn("kvablack/susie")  # downloads the UNet from HuggingFace
image = np.array(Image.open(requests.get(IMAGE_URL, stream=True).raw).resize((256, 256)))
image_out = sample_fn(image, "open the drawer")  # subgoal image for the given instruction

# to display the images if you're in a Jupyter notebook
display(Image.fromarray(image))
display(Image.fromarray(image_out))