huggingface / lerobot

🤗 LeRobot: Making AI for Robotics more accessible with end-to-end learning

Sometimes all actions are padded in the dataset #315

Open StarCycle opened 4 months ago

StarCycle commented 4 months ago

System Info

latest version of lerobot

Reproduction

Please run the following code:

import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

repo_id = "StarCycle/data1"

# Load once without delta_timestamps just to read the dataset's fps.
dataset = LeRobotDataset(repo_id, version=None)
delta_timestamps = {
    "observation.images.static": [t / dataset.fps for t in range(10)],
    "observation.images.gripper": [t / dataset.fps for t in range(10)],
    "observation.state": [t / dataset.fps for t in range(10)],
    "action.rel": [t / dataset.fps for t in range(10)],
    "timestamp": [t / dataset.fps for t in range(10)],
}

# Reload with delta_timestamps so every sample returns a 10-step window
# plus a matching "<key>_is_pad" boolean mask for each listed key.
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps, version=None)
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,
    batch_size=32,
    shuffle=True,
)
for batch in dataloader:
    print(batch["action.rel_is_pad"])
    print(batch["timestamp"])
    import pdb; pdb.set_trace()  # pause to inspect the first batch

Expected behavior

Note that some action sequences are entirely padded (every entry of action.rel_is_pad is True). Does that mean only the first action in the sequence is valid? See the sketch of the expected padding behavior after the output below.

The output is:

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False]])
tensor([[0.8200, 0.8200, 0.8200, 0.8200, 0.8200, 0.8200, 0.8200, 0.8200, 0.8200,
         0.8200],
        [0.3600, 0.3800, 0.4000, 0.4200, 0.4400, 0.4600, 0.4800, 0.5000, 0.5200,
         0.5400],
        [0.1800, 0.2000, 0.2200, 0.2400, 0.2600, 0.2800, 0.3000, 0.3200, 0.3400,
         0.3600],
        [0.6600, 0.6800, 0.7000, 0.7200, 0.7400, 0.7600, 0.7800, 0.8000, 0.8200,
         0.8400],
        [0.2200, 0.2400, 0.2600, 0.2800, 0.3000, 0.3200, 0.3400, 0.3600, 0.3800,
         0.4000],
        [1.2400, 1.2600, 1.2800, 1.2800, 1.2800, 1.2800, 1.2800, 1.2800, 1.2800,
         1.2800],
        [0.4200, 0.4400, 0.4600, 0.4800, 0.5000, 0.5200, 0.5400, 0.5600, 0.5800,
         0.6000],
        [0.0600, 0.0800, 0.1000, 0.1200, 0.1400, 0.1600, 0.1800, 0.2000, 0.2200,
         0.2400],
        [0.1000, 0.1200, 0.1400, 0.1600, 0.1800, 0.2000, 0.2200, 0.2400, 0.2600,
         0.2800],
        [0.7400, 0.7400, 0.7400, 0.7400, 0.7400, 0.7400, 0.7400, 0.7400, 0.7400,
         0.7400],
        [0.2800, 0.3000, 0.3200, 0.3400, 0.3600, 0.3800, 0.4000, 0.4200, 0.4400,
         0.4600],
        [1.1400, 1.1600, 1.1800, 1.2000, 1.2200, 1.2400, 1.2600, 1.2800, 1.2800,
         1.2800],
        [0.0000, 0.0200, 0.0400, 0.0600, 0.0800, 0.1000, 0.1200, 0.1400, 0.1600,
         0.1800],
        [0.4000, 0.4200, 0.4400, 0.4600, 0.4800, 0.5000, 0.5200, 0.5400, 0.5600,
         0.5800],
        [0.1400, 0.1600, 0.1800, 0.2000, 0.2200, 0.2400, 0.2600, 0.2800, 0.3000,
         0.3200],
        [0.8600, 0.8800, 0.9000, 0.9200, 0.9400, 0.9600, 0.9800, 1.0000, 1.0200,
         1.0400],
        [0.0800, 0.1000, 0.1200, 0.1400, 0.1600, 0.1800, 0.2000, 0.2200, 0.2400,
         0.2600],
        [0.1400, 0.1600, 0.1800, 0.2000, 0.2200, 0.2400, 0.2600, 0.2800, 0.3000,
         0.3200],
        [0.5000, 0.5200, 0.5400, 0.5600, 0.5800, 0.6000, 0.6200, 0.6400, 0.6600,
         0.6800],
        [0.1000, 0.1200, 0.1400, 0.1600, 0.1800, 0.2000, 0.2200, 0.2400, 0.2600,
         0.2800],
        [0.2000, 0.2200, 0.2400, 0.2600, 0.2800, 0.3000, 0.3200, 0.3400, 0.3600,
         0.3800],
        [1.2200, 1.2400, 1.2600, 1.2800, 1.2800, 1.2800, 1.2800, 1.2800, 1.2800,
         1.2800],
        [0.8800, 0.9000, 0.9200, 0.9400, 0.9600, 0.9800, 1.0000, 1.0200, 1.0400,
         1.0600],
        [0.3800, 0.4000, 0.4200, 0.4400, 0.4600, 0.4800, 0.5000, 0.5200, 0.5400,
         0.5600],
        [0.4000, 0.4200, 0.4400, 0.4600, 0.4800, 0.5000, 0.5200, 0.5400, 0.5600,
         0.5800],
        [0.5400, 0.5600, 0.5800, 0.6000, 0.6200, 0.6400, 0.6600, 0.6800, 0.7000,
         0.7200],
        [0.5400, 0.5600, 0.5800, 0.6000, 0.6200, 0.6400, 0.6600, 0.6800, 0.7000,
         0.7200],
        [0.2000, 0.2200, 0.2400, 0.2600, 0.2800, 0.3000, 0.3200, 0.3400, 0.3600,
         0.3800],
        [0.5800, 0.6000, 0.6200, 0.6400, 0.6600, 0.6800, 0.7000, 0.7200, 0.7400,
         0.7600],
        [0.2600, 0.2800, 0.3000, 0.3200, 0.3400, 0.3600, 0.3800, 0.4000, 0.4200,
         0.4400],
        [0.3400, 0.3600, 0.3800, 0.4000, 0.4200, 0.4400, 0.4600, 0.4800, 0.5000,
         0.5200],
        [0.4000, 0.4200, 0.4400, 0.4600, 0.4800, 0.5000, 0.5200, 0.5400, 0.5600,
         0.5800]])
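
For context, here is how I understand the padding is supposed to work: queried timestamps that fall outside the episode are clamped to the nearest valid frame and flagged in the *_is_pad mask. A minimal sketch of that rule (the function name and tolerance are illustrative, not lerobot's actual implementation):

import torch

def expected_is_pad(query_ts, ep_start_ts, ep_end_ts, tol_s=1e-4):
    # A step counts as padded when its queried timestamp falls outside the
    # episode's [start, end] range by more than a small tolerance; the
    # returned frame is then clamped to the episode boundary.
    query_ts = torch.as_tensor(query_ts)
    return (query_ts < ep_start_ts - tol_s) | (query_ts > ep_end_ts + tol_s)

fps = 50
deltas = torch.arange(10) / fps
# e.g. querying 10 steps ahead of t=1.22 s in an episode ending at 1.28 s:
print(expected_is_pad(1.22 + deltas, ep_start_ts=0.0, ep_end_ts=1.28))
# tensor([False, False, False, False,  True,  True,  True,  True,  True,  True])

The rows above that flip to True mid-sequence, with timestamps clamped at 1.2800, match this rule. The all-True row with a constant timestamp of 0.8200 does not: the t=0 query should always hit the current frame, so at least the first entry should be False. That is what looks wrong here.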
StarCycle commented 4 months ago

The dataset was generated with the following lmdb_format.py:

import gc
import shutil
from pathlib import Path

import lmdb
from pickle import loads
from torchvision.io import decode_jpeg
import numpy as np
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

VIDEO_SHARD_SIZE = 2000

def save_video(ep_idx, camera, videos_dir, imgs_array, fps, num_frames, ep_dict):
    img_key = f"observation.images.{camera}"
    tmp_imgs_dir = videos_dir / "tmp_images"
    save_images_concurrently(imgs_array, tmp_imgs_dir)

    # encode images to a mp4 video
    fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
    dir_start_ep = (ep_idx//VIDEO_SHARD_SIZE)*VIDEO_SHARD_SIZE
    dir_end_ep = (ep_idx//VIDEO_SHARD_SIZE + 1)*VIDEO_SHARD_SIZE - 1
    video_path = videos_dir / f'{dir_start_ep}-{dir_end_ep}' / fname
    encode_video_frames(tmp_imgs_dir, video_path, fps, video_codec="libx264")

    # clean temporary images directory
    shutil.rmtree(tmp_imgs_dir)

    # store the reference to the video frame
    ep_dict[img_key] = [
        {"path": f"videos/{dir_start_ep}-{dir_end_ep}/{fname}", "timestamp": i / fps} for i in range(num_frames)
    ]

def string_to_utf8_array(s, length=72):
    # Encode the string to UTF-8
    utf8_encoded = s.encode('utf-8')

    # Convert to list of integers
    utf8_array = list(utf8_encoded)

    # Ensure the array length is exactly `length`
    if len(utf8_array) < length:
        # Pad with zeros if shorter
        utf8_array += [0] * (length - len(utf8_array))
    else:
        # Trim if longer
        utf8_array = utf8_array[:length]

    return utf8_array
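
# A hypothetical inverse helper (not used by the rest of this script) for
# sanity-checking the fixed-length encoding: strip the zero padding and
# decode the bytes back to a string.
def utf8_array_to_string(utf8_array):
    return bytes(int(b) for b in utf8_array).rstrip(b'\x00').decode('utf-8')

# e.g. utf8_array_to_string(string_to_utf8_array("lift the red block"))
# returns "lift the red block"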

def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):

    env = lmdb.open(str(raw_dir), readonly=True, create=False, lock=False)

    ep_dicts = []
    with env.begin() as txn:
        dataset_len = loads(txn.get('cur_step'.encode())) + 1
        inst_list = []
        inst_token_list = []
        rgb_static_list = []
        rgb_gripper_list = []
        state_list = []
        abs_action_list = []
        rel_action_list = []
        done_list = []
        last_ep_idx = loads(txn.get('cur_episode_0'.encode()))
        ep_start = 0
        for idx in range(dataset_len):
            ep_idx = loads(txn.get(f'cur_episode_{idx}'.encode()))
            if ep_idx == last_ep_idx + 1:
                print(f'{idx}/{dataset_len}')

                done_list[-1] = True
                num_frames = idx - ep_start
                ep_start = idx

                ep_dict = {}
                save_video(ep_idx, 'static', videos_dir, rgb_static_list, fps, num_frames, ep_dict)
                save_video(ep_idx, 'gripper', videos_dir, rgb_gripper_list, fps, num_frames, ep_dict)
                ep_dict["observation.state"] = torch.stack(state_list)
                ep_dict["inst"] = torch.stack(inst_list)
                ep_dict["inst_token"] = torch.stack(inst_token_list)
                ep_dict["action.abs"] = torch.stack(abs_action_list)
                ep_dict["action.rel"] = torch.stack(rel_action_list)
                ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
                ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
                ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
                ep_dict["next.done"] = torch.tensor(done_list)
                ep_dicts.append(ep_dict)

                inst_list = []
                inst_token_list = []
                rgb_static_list = []
                rgb_gripper_list = []
                state_list = []
                abs_action_list = []
                rel_action_list = []
                done_list = []
                last_ep_idx = ep_idx 
            inst = torch.tensor(string_to_utf8_array(loads(txn.get(f'inst_{ep_idx}'.encode()))))
            inst_list.append(inst)
            inst_token_list.append(loads(txn.get(f'inst_token_{ep_idx}'.encode())))
            rgb_static_list.append(decode_jpeg(loads(txn.get(f'rgb_static_{idx}'.encode()))).permute(1, 2, 0).numpy())
            rgb_gripper_list.append(decode_jpeg(loads(txn.get(f'rgb_gripper_{idx}'.encode()))).permute(1, 2, 0).numpy())
            state_list.append(loads(txn.get(f'robot_obs_{idx}'.encode())))
            abs_action_list.append(loads(txn.get(f'abs_action_{idx}'.encode())))
            rel_action_list.append(loads(txn.get(f'rel_action_{idx}'.encode())))
            done_list.append(False)

        # NOTE: frames accumulated for the final episode after the last
        # boundary are never flushed into ep_dicts, so the last episode
        # is silently dropped.
        gc.collect()

    data_dict = concatenate_episodes(ep_dicts)

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict

def to_hf_dataset(data_dict, video) -> Dataset:
    features = {}

    keys = [key for key in data_dict if "observation.images." in key]
    for key in keys:
        if video:
            features[key] = VideoFrame()
        else:
            features[key] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    if "observation.velocity" in data_dict:
        features["observation.velocity"] = Sequence(
            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
        )
    if "observation.effort" in data_dict:
        features["observation.effort"] = Sequence(
            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
        )
    features["action.abs"] = Sequence(
        length=data_dict["action.abs"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["action.rel"] = Sequence(
        length=data_dict["action.rel"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["inst"] = Sequence(
        length=data_dict["inst"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["inst_token"] = Sequence(
        length=data_dict["inst_token"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset

def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
):

    if fps is None:
        fps = 50

    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
    hf_dataset = to_hf_dataset(data_dict, video)
    episode_data_index = calculate_episode_data_index(hf_dataset)
    info = {
        "fps": fps,
        "video": video,
    }
    return hf_dataset, episode_data_index, info
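
To rule out a conversion problem on my side, a quick sanity check over data_dict (the helper below is illustrative, not a lerobot API) verifies that frame_index and timestamp restart at zero in every episode:

import torch

def check_episode_boundaries(data_dict, fps):
    # Verify frame_index and timestamp restart at 0 within each episode.
    ep_index = data_dict["episode_index"]
    for ep in torch.unique(ep_index).tolist():
        mask = ep_index == ep
        frames = data_dict["frame_index"][mask]
        ts = data_dict["timestamp"][mask]
        assert frames[0].item() == 0, f"episode {ep} does not start at frame 0"
        assert torch.equal(frames, torch.arange(len(frames))), f"episode {ep} has frame gaps"
        assert torch.allclose(ts, frames / fps), f"episode {ep}: timestamp != frame_index / fps"

Episodes whose timestamps do not restart at zero could produce exactly the kind of all-padded windows shown above.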

Since CALVIN has too many episodes, I took @Cadene's advice to split the mp4 files across multiple folders and upload them one by one:

from huggingface_hub import HfApi
import os

api = HfApi()
api.upload_folder(
    folder_path="./train",
    path_in_repo="train",
    repo_id="StarCycle/data1",
    repo_type="dataset",
)
api.upload_folder(
    folder_path="./meta_data",
    path_in_repo="meta_data",
    repo_id="StarCycle/data1",
    repo_type="dataset",
)
directories = [f for f in os.listdir('./videos') if os.path.isdir(os.path.join('./videos', f))]
for dir_i in directories:
    print(dir_i)
    api.upload_folder(
        folder_path="./videos/"+dir_i,
        path_in_repo="videos/"+dir_i,
        repo_id="StarCycle/data1",
        repo_type="dataset",
    )
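
Uploading many shards in a row can fail partway through on a flaky connection, so a small retry wrapper around upload_folder (illustrative; the helper and backoff policy are mine) keeps the loop going:

import time

def upload_with_retry(api, folder_path, path_in_repo, repo_id, max_retries=5):
    for attempt in range(max_retries):
        try:
            api.upload_folder(
                folder_path=folder_path,
                path_in_repo=path_in_repo,
                repo_id=repo_id,
                repo_type="dataset",
            )
            return
        except Exception as err:  # broad catch for transient HTTP/timeout errors
            print(f"upload of {folder_path} failed ({attempt + 1}/{max_retries}): {err}")
            time.sleep(2 ** attempt)  # exponential backoff between attempts
    raise RuntimeError(f"giving up on {folder_path} after {max_retries} attempts")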