mateuszwyszynski / octo

Octo is a transformer-based robot policy trained on a diverse mix of 800k robot trajectories.
https://octo-models.github.io/

eval and train data format mismatch #3

Closed mateuszwyszynski closed 2 months ago

mateuszwyszynski commented 2 months ago

For the ALOHA dataset, one can generate the first proprio state using code from the finetuning script:

from octo.data.dataset import make_single_dataset

dataset = make_single_dataset(
    dataset_kwargs=dict(
        name="aloha_sim_cube_scripted_dataset",
        data_dir='path/to/tensorflow_datasets',
        image_obs_keys={"primary": "top"},
        proprio_obs_key="state",
        language_key="language_instruction",
    ),
    traj_transform_kwargs=dict(
        window_size=1,
        action_horizon=50,
    ),
    frame_transform_kwargs=dict(
        resize_size={"primary": (256, 256)},
    ),
    train=True,
)

ds_iter = dataset.iterator()
x_t = next(ds_iter)  # first trajectory from the training pipeline
x_t['observation']['proprio'][0]  # proprio of the first frame

and the result will be the same as when one uses the code from the evaluation script:

import copy
import sys
from typing import List

import dlimp as dl
import gymnasium as gym
import jax.numpy as jnp
import numpy as np

# https://github.com/tonyzhaozh/act must be on your PATH for this import to work
sys.path.append("path/to/your/act")
from sim_env import BOX_POSE, make_sim_env

class AlohaGymEnv(gym.Env):
    def __init__(
        self,
        env: gym.Env,
        camera_names: List[str],
        im_size: int = 256,
        seed: int = 1234,
    ):
        self._env = env
        self.observation_space = gym.spaces.Dict(
            {
                **{
                    f"image_{i}": gym.spaces.Box(
                        low=np.zeros((im_size, im_size, 3)),
                        high=255 * np.ones((im_size, im_size, 3)),
                        dtype=np.uint8,
                    )
                    for i in ["primary", "wrist"][: len(camera_names)]
                },
                "proprio": gym.spaces.Box(
                    low=np.ones((14,)) * -1, high=np.ones((14,)), dtype=np.float32
                ),
            }
        )
        self.action_space = gym.spaces.Box(
            low=np.ones((14,)) * -1, high=np.ones((14,)), dtype=np.float32
        )
        self.camera_names = camera_names
        self._im_size = im_size
        self._rng = np.random.default_rng(seed)

    def step(self, action):
        ts = self._env.step(action)
        obs, images = self.get_obs(ts)
        reward = ts.reward
        info = {"images": images}

        if reward == self._env.task.max_reward:
            self._episode_is_success = 1

        return obs, reward, False, False, info

    def reset(self, **kwargs):
        # sample new box pose
        x_range = [0.0, 0.2]
        y_range = [0.4, 0.6]
        z_range = [0.05, 0.05]
        ranges = np.vstack([x_range, y_range, z_range])
        cube_position = self._rng.uniform(ranges[:, 0], ranges[:, 1])
        cube_quat = np.array([1, 0, 0, 0])
        BOX_POSE[0] = np.concatenate([cube_position, cube_quat])

        ts = self._env.reset()
        obs, images = self.get_obs(ts)
        info = {"images": images}
        self._episode_is_success = 0

        return obs, info

    def get_obs(self, ts):
        curr_obs = {}
        vis_images = []

        obs_img_names = ["primary", "wrist"]
        for i, cam_name in enumerate(self.camera_names):
            curr_image = ts.observation["images"][cam_name]
            vis_images.append(copy.deepcopy(curr_image))
            curr_image = jnp.array(curr_image)
            curr_obs[f"image_{obs_img_names[i]}"] = curr_image
        curr_obs = dl.transforms.resize_images(
            curr_obs, match=curr_obs.keys(), size=(self._im_size, self._im_size)
        )

        qpos_numpy = np.array(ts.observation["qpos"])
        qpos = jnp.array(qpos_numpy)
        curr_obs["proprio"] = qpos

        return curr_obs, np.concatenate(vis_images, axis=-2)

    def get_task(self):
        return {
            "language_instruction": ["pick up the cube and hand it over"],
        }

    def get_episode_metrics(self):
        return {
            "success_rate": self._episode_is_success,
        }

gym.register(
    "aloha-sim-cube-v0",
    entry_point=lambda: AlohaGymEnv(
        make_sim_env("sim_transfer_cube"), camera_names=["top"]
    ),
)

from octo.model.octo_model import OctoModel
from octo.utils.gym_wrappers import HistoryWrapper, NormalizeProprio, RHCWrapper
from octo.utils.train_callbacks import supply_rng

model = OctoModel.load_pretrained('checkpoints/aloha-0')

env = gym.make("aloha-sim-cube-v0")

# wrap env to normalize proprio
env = NormalizeProprio(env, model.dataset_statistics)

# add wrappers for history and "receding horizon control", i.e. action chunking
env = HistoryWrapper(env, horizon=1)
env = RHCWrapper(env, exec_horizon=50)

obs, info = env.reset()
obs['proprio']
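To make "the same" concrete, this is roughly the comparison I have in mind, assuming both snippets above were run in the same Python session so that x_t and obs are in scope (the tolerance is arbitrary):

import numpy as np

train_proprio = np.asarray(x_t['observation']['proprio'][0])  # shape (window_size, proprio_dim)
eval_proprio = np.asarray(obs['proprio'])                      # shape (horizon, proprio_dim)
print(train_proprio)
print(eval_proprio)
print('match:', np.allclose(train_proprio, eval_proprio, atol=1e-2))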

When we do the same for the RLBench scripts, i.e. first run:

from octo.data.dataset import make_single_dataset

train_dataset = make_single_dataset(
    dataset_kwargs=dict(
        name="rl_bench_dataset",
        data_dir='/abs/path/to/tensorflow_datasets',
        image_obs_keys={"primary": "image", "wrist": "wrist_image"},
        proprio_obs_key="proprio",
        language_key="language_instruction",
        # We want to avoid normalizing the gripper
        action_normalization_mask=[True, True, True, True, True, True, False],
    ),
    traj_transform_kwargs=dict(
        window_size=1,
        action_horizon=50,
    ),
    frame_transform_kwargs=dict(
        resize_size={"primary": (256, 256), "wrist": (256, 256)},
    ),
    train=True,
)

train_data_iter = train_dataset.iterator()

x_t = next(train_data_iter)

x_t['observation']['proprio'][0]

and then:

from octo.model.octo_model import OctoModel

model = OctoModel.load_pretrained('../checkpoints/test-03-09-24-0')

use_proprio = "proprio" in model.config["model"]["observation_tokenizers"]
task_name = "place_shape_in_shape_sorter-vision-v0"
if use_proprio:
    task_name = f"{task_name}-proprio"

import gymnasium as gym
from envs.action_modes import UR5ActionMode
from envs.rl_bench_ur5_env import RLBenchUR5Env

from rlbench.utils import name_to_task_class

gym.register(
    task_name,
    entry_point=lambda: RLBenchUR5Env(
        task_class=name_to_task_class("place_shape_in_shape_sorter"),
        observation_mode='vision', render_mode="rgb_array",
        robot_setup="ur5", headless=True,
        action_mode=UR5ActionMode(), proprio=use_proprio
        )
)

env = gym.make(task_name)

from octo.utils.gym_wrappers import HistoryWrapper, NormalizeProprio, RHCWrapper

env = HistoryWrapper(env, horizon=1)
env = RHCWrapper(env, exec_horizon=50)

obs, info = env.reset(options={"variation": 0})

obs['proprio'][0]

we get different results.

So it seems that there is a mismatch between the data used for training and the data used for evaluation.
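For completeness, the comparison was done the same way as for ALOHA; without normalization the check below fails (again assuming x_t and obs from the two snippets above are in scope):

import numpy as np

train_proprio = np.asarray(x_t['observation']['proprio'][0])  # presumably normalized by the data pipeline
eval_proprio = np.asarray(obs['proprio'])                      # un-normalized values straight from the env
print(train_proprio)
print(eval_proprio)
print('match:', np.allclose(train_proprio, eval_proprio, atol=1e-2))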

mateuszwyszynski commented 2 months ago

I believe the problem is caused by the lack of normalization, i.e. we have to wrap the environment with NormalizeProprio after registering it with gym:

env = gym.make(task_name)

from octo.utils.gym_wrappers import HistoryWrapper, NormalizeProprio, RHCWrapper

env = NormalizeProprio(env, model.dataset_statistics)

env = HistoryWrapper(env, horizon=1)
env = RHCWrapper(env, exec_horizon=50)

obs, info = env.reset(options={"variation": 0})

obs['proprio'][0]

After adding the normalization line, the numbers we get for the initial joint positions are approximately the same.
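For context, my understanding is that NormalizeProprio simply applies mean/std normalization with the statistics stored in the checkpoint, i.e. the same transform the data pipeline applies during training. A minimal sketch of that idea (the exact keys in model.dataset_statistics and the epsilon are assumptions on my part, not octo's actual implementation):

import numpy as np

def normalize_proprio(proprio, stats):
    # stats is assumed to expose per-dimension 'mean' and 'std' arrays,
    # e.g. something like model.dataset_statistics['proprio']
    return (np.asarray(proprio) - stats['mean']) / (stats['std'] + 1e-8)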

I believe the small differences (of the order $10^{-3}$) are acceptable, because differences of the same magnitude are present between episodes themselves. More precisely, if we run env.reset multiple times, we get slightly different values for the starting position. Hence my belief is that they are simply caused by some kind of randomization or by numerical precision.
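The episode-to-episode spread can be checked directly by resetting the wrapped environment a few times and looking at the range of the first proprio reading; a sketch of that check, with the env and wrappers set up as above:

import numpy as np

first_proprios = []
for _ in range(5):
    obs, info = env.reset(options={"variation": 0})
    first_proprios.append(np.asarray(obs['proprio'][0]))

first_proprios = np.stack(first_proprios)
# per-dimension spread across resets; this is also on the order of 1e-3
print(first_proprios.max(axis=0) - first_proprios.min(axis=0))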