openvla / openvla

OpenVLA: An open-source vision-language-action model for robotic manipulation.
MIT License

Un-normalizing statistics for Google robot tasks #21

Closed: HuFY-dev closed this issue 4 months ago

HuFY-dev commented 4 months ago

Hi, I'm experimenting with the OpenVLA model and want to evaluate it on SimplerEnv's Google robot tasks. Since the predict_action method takes an unnorm_key argument, I assume it relates to some dataset-specific bias terms and has to be changed if I use other datasets. However, I noticed that the openvla-7b model has no Google-related keys in its norm_stats dictionary. Does that mean I have to retrain or fine-tune the model on Google tasks? Also, I saw that you ran Google robot evaluations in the paper; have you released the related code? Thanks in advance for your help!

HuFY-dev commented 4 months ago

Never mind, I think I found the unnorm_key for Google: fractal20220817_data. I also want to confirm that I should use bridge_orig for SimplerEnv's WidowX tasks, right?
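For context on what the key actually selects: below is a minimal sketch of per-dataset action un-normalization. It assumes the 1st/99th-percentile bounds scheme (`q01`/`q99`, plus an optional `mask` that leaves the gripper dimension untouched) that, to my understanding, OpenVLA applies to its [-1, 1] action outputs. The statistics are made-up placeholders, not the real values stored in `norm_stats`, and `unnormalize_action` is a hypothetical helper, not part of the OpenVLA API.

```python
import numpy as np

# Placeholder statistics keyed by dataset name (NOT the real OpenVLA values).
NORM_STATS = {
    "bridge_orig": {
        "action": {
            "q01": [-0.03, -0.04, -0.03, -0.1, -0.1, -0.2, 0.0],
            "q99": [0.03, 0.04, 0.04, 0.1, 0.1, 0.2, 1.0],
            "mask": [True, True, True, True, True, True, False],
        }
    },
    "fractal20220817_data": {
        "action": {
            "q01": [-0.2, -0.3, -0.2, -0.4, -0.3, -0.6, 0.0],
            "q99": [0.2, 0.3, 0.2, 0.4, 0.3, 0.6, 1.0],
            "mask": [True, True, True, True, True, True, False],
        }
    },
}


def unnormalize_action(normalized: np.ndarray, unnorm_key: str) -> np.ndarray:
    """Map a normalized action in [-1, 1] back to the chosen dataset's action scale."""
    if unnorm_key not in NORM_STATS:
        raise KeyError(f"Unknown unnorm_key {unnorm_key!r}; options: {sorted(NORM_STATS)}")
    stats = NORM_STATS[unnorm_key]["action"]
    low, high = np.asarray(stats["q01"]), np.asarray(stats["q99"])
    mask = np.asarray(stats.get("mask", np.ones_like(low, dtype=bool)))
    # Linear map [-1, 1] -> [q01, q99] on masked dims; unmasked dims (gripper) pass through.
    return np.where(mask, 0.5 * (normalized + 1.0) * (high - low) + low, normalized)
```

The same normalized model output decodes to different physical motions under `bridge_orig` and `fractal20220817_data`, which is why a wrong key can make an otherwise reasonable policy look broken.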

HuFY-dev commented 4 months ago

I'm struggling to get OpenVLA to work in scenarios like this one:

Task: pick up spoon and place on towel

[Rollout video of the model output omitted]

I was using bridge_orig as the unnorm key, but that might not be the correct option. Do you know which key I should choose?

HuFY-dev commented 4 months ago

Hi @siddk @kpertsch @moojink, sorry to ping you, but do you have any idea which unnorm key I should use in the scenario above? Thanks in advance for the help!

kpertsch commented 4 months ago

Hey, sorry for the late reply! The unnorm_keys you are using are correct for both Bridge and Google Robot. When I tried SIMPLER before, OpenVLA seemed to work pretty well on the Google robot tasks and less well on the Bridge tasks, since the visual real-to-sim domain gap is wider there. Can you try the Google robot tasks?
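The key choice confirmed here can be made mechanical. A small sketch, assuming SIMPLER task names contain `widowx` or `google_robot` (as the script's default `widowx_spoon_on_towel` and its `"google_robot" in cfg.task` check suggest); `unnorm_key_for_task` and `binarize_bridge_gripper` are hypothetical helper names. The gripper mapping is the same `2.0 * (g > 0.5) - 1.0` rule the script applies for Bridge.

```python
import numpy as np

# Dataset statistics to un-normalize with, per SIMPLER robot (keys confirmed in this thread).
UNNORM_KEY_BY_ROBOT = {
    "widowx": "bridge_orig",
    "google_robot": "fractal20220817_data",
}


def unnorm_key_for_task(task: str) -> str:
    """Pick the unnorm_key from a SIMPLER task name such as 'widowx_spoon_on_towel'."""
    for robot, key in UNNORM_KEY_BY_ROBOT.items():
        if robot in task:
            return key
    raise ValueError(f"Cannot infer robot from task name {task!r}")


def binarize_bridge_gripper(action: np.ndarray) -> np.ndarray:
    """Map the gripper output in [0, 1] to the {-1, +1} the Bridge env expects (threshold 0.5)."""
    action = action.copy()  # leave the caller's array untouched
    action[-1] = 2.0 * float(action[-1] > 0.5) - 1.0
    return action
```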

I'm pasting my SIMPLER script below. It won't quite run out of the box since it's not yet adapted to the release code (we need a bit of time to clean this up), but I figured it could help you to see the sticky action logic required for the Google Robot (this is copied from the SIMPLER Octo eval, which used the same logic).

The reason Octo and OpenVLA need a sticky gripper is that during their training we convert the "relative gripper actions" (+1 for opening, -1 for closing) into "absolute gripper actions" (+1 for opened, -1 for closed). So during inference we need to convert them back to relative, and we also need to apply each relative opening or closing action for multiple timesteps in a row for the gripper to actually fully close.
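The conversion described above can be sketched as a small stateful class. This re-packages the per-episode logic in the script below (`previous_gripper_action`, `sticky_action_is_on`, `gripper_action_repeat`); it is not an official API, the class name `StickyGripper` is hypothetical, and `num_repeat` plays the role of `sticky_gripper_num_repeat`. It assumes the absolute gripper value arrives in [0, 1] with 1 = open, so `previous - current` is +1 when the gripper goes from open to closed, matching the script's comment "google robot 1 = close; -1 = open".

```python
from typing import Optional


class StickyGripper:
    """Convert absolute gripper commands to relative ones, holding each change for num_repeat steps.

    Assumes absolute input in [0, 1] with 1 = open, so the relative output is
    +1 for "close" and -1 for "open" (Google robot convention).
    """

    def __init__(self, num_repeat: int = 10) -> None:
        self.num_repeat = num_repeat
        self.reset()

    def reset(self) -> None:
        # Call at the start of every episode.
        self.previous: Optional[float] = None
        self.sticky_on = False
        self.sticky_value = 0.0
        self.repeat = 0

    def __call__(self, absolute: float) -> float:
        # First step of an episode: no previous command, so no gripper change.
        relative = 0.0 if self.previous is None else self.previous - absolute
        self.previous = absolute

        # A large enough change opens a "sticky" window.
        if abs(relative) > 0.5 and not self.sticky_on:
            self.sticky_on = True
            self.sticky_value = relative

        # While sticky, keep re-issuing the same relative command.
        if self.sticky_on:
            self.repeat += 1
            relative = self.sticky_value

        # After num_repeat steps, release the gripper command.
        if self.repeat == self.num_repeat:
            self.sticky_on = False
            self.repeat = 0
            self.sticky_value = 0.0

        return relative
```

One consequence of this design worth knowing: because `previous` is updated on every step, a reversal that arrives while a sticky window is still open is silently dropped rather than queued.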

Hope this helps your investigations!

"""
eval_model_in_simpler_env.py

Runs a model checkpoint in a simulated SIMPLER environment.

Usage:
    python experiments/robot/libero/eval_model_in_simpler_env.py \
        --model.type <VLM_TYPE> \
        --pretrained_checkpoint <CHECKPOINT_PATH>
"""

import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union

import draccus
import numpy as np
import tqdm

import wandb
from prismatic.conf import ModelConfig, ModelRegistry

# TODO (moojink) Hack so that the interpreter can find experiments.robot
sys.path.append("../..")
from experiments.robot.simpler.simpler_utils import (
    get_simpler_env,
    get_simpler_img,
    process_simpler_action,
)
from experiments.robot.utils import (
    get_action,
    get_image_resize_size,
    get_model,
)

assert "MS2_ASSET_DIR" in os.environ, (
    "Environment variable MS2_ASSET_DIR not set. "
    "Usage: `MS2_ASSET_DIR=./ManiSkill2_real2sim/data python test_real2sim.py ...`"
)
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")

@dataclass
class GenerateConfig:
    # fmt: off

    # Pre-trained model class
    model: ModelConfig = field(
        default_factory=ModelConfig.get_choice_class(ModelRegistry.REPRODUCTION_7B.model_id)
    )
    model_family: str = "llava"

    # Model Parameters
    pretrained_checkpoint: Union[str, Path] = Path(
        "/shared/karl/models/open_vla/siglip-224px+mx-oxe-magic-soup+n8+b32+x7/"
        "checkpoints/step-152500-epoch-27-loss=0.1637.pt"
    )

    # Choose from 'bridge_orig' / 'fractal20220817_data'
    unnorm_key: str = "bridge_orig"                             # Dataset name for action unnormalization
    center_crop: bool = False                                   # Center crop? (if trained w/ random crop image aug)

    # Task to eval. For options, run: import simpler_env; simpler_env.ENVIRONMENTS
    task: str = "widowx_spoon_on_towel"

    # SIMPLER-related args
    sticky_gripper_num_repeat: int = 10                         # Number of steps for Google Robot sticky gripper action
    num_trials: int = 5                                         # Number of rollouts per task
    num_save_videos: int = 5                                    # Number of videos to be logged per task
    video_temp_subsample: int = 1                               # Temporal subsampling to make videos shorter

    # Weights & Biases
    wandb_project: str = "openvla"                              # Name of W&B project to log to (use default!)
    wandb_entity: str = "stanford-voltron"                      # Name of entity to log under

    # HF Hub Credentials (for LLaMa-2)
    hf_token: Union[str, Path] = Path(".hf_token")              # Environment variable or Path to HF Token

    # Randomness
    seed: int = 21                                              # Random Seed (for reproducibility)
    # fmt: on

@draccus.wrap()
def eval_simpler(cfg: GenerateConfig) -> None:
    assert cfg.pretrained_checkpoint is not None, "cfg.pretrained_checkpoint must not be None!"
    assert cfg.model_family == "llava", "Only OpenVLA evaluation supported for now."

    # Load Model
    model = get_model(cfg)

    # Initialize W&B
    wandb.init(
        entity=cfg.wandb_entity,
        project=cfg.wandb_project,
        name=f"EVAL-SIMPLER-{cfg.task}-{cfg.model_family}-{DATE_TIME}",
    )

    # Get Expected Image Dimensions
    resize_size = get_image_resize_size(cfg)

    # Initialize the SIMPLER environment.
    env = get_simpler_env(cfg.task)

    # Start episodes.
    task_episodes, task_successes = 0, 0
    for episode_idx in tqdm.tqdm(range(cfg.num_trials)):

        # Reset environment.
        obs, _ = env.reset()
        language_instruction = env.get_language_instruction()
        print(f"\nTask: {language_instruction}")

        # Setup.
        done, truncated = False, False
        previous_gripper_action = None
        sticky_action_is_on = False
        gripper_action_repeat = 0
        rollout_images = []
        print(f"Starting episode {episode_idx+1}...")
        while not (done or truncated):
            # Get preprocessed image.
            img = get_simpler_img(env, obs, resize_size)
            rollout_images.append(img)

            # Generate action with model.
            observation = {
                "full_image": img,
            }
            action = get_action(cfg, model, observation, language_instruction, policy_function=None)

            # Modify action to comply with SIMPLER env
            action = process_simpler_action(action)
            if "widowx" in cfg.task:
                # Map gripper [0, 1] --> [-1, 1] for Bridge env
                action[-1:] = 2.0 * (action[-1:] > 0.5) - 1.0
            elif "google_robot" in cfg.task:
                # Convert action from absolute --> relative for Google robot using "sticky" actions
                current_gripper_action = action[-1]
                if previous_gripper_action is None:
                    relative_gripper_action = 0.0  # First step: no previous command, so no gripper change
                else:
                    relative_gripper_action = (
                        previous_gripper_action - current_gripper_action
                    )  # google robot 1 = close; -1 = open
                previous_gripper_action = current_gripper_action

                if np.abs(relative_gripper_action) > 0.5 and sticky_action_is_on is False:
                    sticky_action_is_on = True
                    sticky_gripper_action = relative_gripper_action

                if sticky_action_is_on:
                    gripper_action_repeat += 1
                    relative_gripper_action = sticky_gripper_action

                if gripper_action_repeat == cfg.sticky_gripper_num_repeat:
                    sticky_action_is_on = False
                    gripper_action_repeat = 0
                    sticky_gripper_action = 0.0

                action[-1] = relative_gripper_action

            # Execute action in environment.
            obs, reward, done, truncated, info = env.step(action)

        task_episodes += 1
        task_successes += float(done)
        print(f"Success rate: {float(task_successes) / float(task_episodes)}")
        if episode_idx < cfg.num_save_videos:
            # Log the rollout video to W&B.
            wandb.log(
                {
                    f"video_{episode_idx}": wandb.Video(
                        np.array(rollout_images[:: cfg.video_temp_subsample]).transpose(0, 3, 1, 2)
                    )
                }
            )

    # Log and update total metrics
    wandb.log(
        {
            "success_rate": float(task_successes) / float(task_episodes),
            "num_episodes": task_episodes,
        }
    )

if __name__ == "__main__":
    eval_simpler()
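
The `wandb.Video` call in the script relies on two array manipulations that are easy to get wrong: temporal subsampling with `[::video_temp_subsample]`, and reordering the stacked frames from (T, H, W, C), the layout the collected RGB frames are assumed to have, to (T, C, H, W), the layout `wandb.Video` expects for NumPy input. A standalone sketch with dummy frames; `stack_rollout_frames` is a hypothetical helper.

```python
from typing import List

import numpy as np


def stack_rollout_frames(frames: List[np.ndarray], subsample: int = 1) -> np.ndarray:
    """Stack HxWxC frames into a (T, C, H, W) array, keeping every `subsample`-th frame."""
    if subsample < 1:
        raise ValueError("subsample must be >= 1")
    video = np.stack(frames[::subsample])  # (T, H, W, C)
    return video.transpose(0, 3, 1, 2)     # (T, C, H, W)
```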

HuFY-dev commented 4 months ago

Thank you!

zhou-pig commented 2 months ago


Hello, my friend! I have the same problem as you. No matter which unnorm_key I choose, OpenVLA doesn't work well in SimplerEnv. Have you solved the problem?