huggingface / lerobot

🤗 LeRobot: Making AI for Robotics more accessible with end-to-end learning
Apache License 2.0

feat: enable to use multiple rgb encoders per camera in diffusion policy #484

Closed: HiroIshida closed this 3 weeks ago

HiroIshida commented 1 month ago

What this does

This PR makes it possible to configure a separate RGB encoder for each camera, as is done in the original Diffusion Policy paper. Currently, the diffusion policy shares a single RGB encoder across all cameras. I am not sure this helps in general, but at least in my task it seems to slightly increase the success rate at deployment. This fixes https://github.com/huggingface/lerobot/issues/483
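To make the difference between the two configurations concrete, here is a minimal PyTorch sketch. It is not the lerobot implementation: `make_encoder` and `MultiCameraEncoder` are hypothetical names, and the tiny conv stack stands in for the real image backbone. The shared variant folds the cameras into the batch dimension and runs one encoder; the per-camera variant keeps an `nn.ModuleList` with one encoder per camera and concatenates their features.

```python
import torch
from torch import nn


def make_encoder(feature_dim: int = 32) -> nn.Module:
    # Hypothetical tiny stand-in for the real image backbone.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, stride=2),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, feature_dim),
    )


class MultiCameraEncoder(nn.Module):
    """Encodes N camera images into one concatenated feature vector per sample."""

    def __init__(self, num_cameras: int, separate: bool, feature_dim: int = 32):
        super().__init__()
        self.separate = separate
        if separate:
            # One independently initialized encoder per camera.
            self.encoders = nn.ModuleList([make_encoder(feature_dim) for _ in range(num_cameras)])
        else:
            # One encoder whose weights are shared by every camera.
            self.encoder = make_encoder(feature_dim)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # images: (batch, num_cameras, 3, H, W)
        if self.separate:
            feats = [enc(images[:, i]) for i, enc in enumerate(self.encoders)]
            return torch.cat(feats, dim=-1)
        batch, num_cameras = images.shape[:2]
        # Fold cameras into the batch dimension, run the shared encoder once, unfold.
        feats = self.encoder(images.flatten(0, 1))  # (batch * num_cameras, feature_dim)
        return feats.reshape(batch, num_cameras * feats.shape[-1])


x = torch.rand(2, 2, 3, 56, 56)  # batch of 2, two cameras (e.g. gripper + webcam)
shared = MultiCameraEncoder(num_cameras=2, separate=False)
per_camera = MultiCameraEncoder(num_cameras=2, separate=True)
print(shared(x).shape, per_camera(x).shape)  # torch.Size([2, 64]) torch.Size([2, 64])
print(sum(p.numel() for p in per_camera.parameters())
      // sum(p.numel() for p in shared.parameters()))  # 2
```

Both variants produce features of the same size, so whatever consumes them downstream does not need to change; the trade-off is that the visual encoder's parameter count grows linearly with the number of cameras.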

How it was tested

a) regression test

Prepare the following script and check that the policy's behavior does not change from before to after the PR by comparing the hash value of its output.

import yaml
import torch
from hashlib import sha256
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy, DiffusionConfig

def run(camera_names, rgb_encoder_per_camera):
    torch.manual_seed(0)

    resol = 56
    input_shapes = {"observation.state": [6]}
    for name in camera_names:
        input_shapes[f"observation.image.{name}"] = [3, resol, resol]
    output_shapes = {"action": [6]}

    # determine normalizer
    normalization_mode = {"observation.state": "min_max"}
    for name in camera_names:
        normalization_mode[f"observation.image.{name}"] = "mean_std"

    # load stats
    with open('stats-diffusion.yaml', 'r') as f:
        stats = yaml.load(f, Loader=yaml.FullLoader)
    effective_keys = list(output_shapes.keys()) + list(input_shapes.keys()) + ["episode_index", "frame_index", "index", "next.done", "timestamp"]
    effective_key_set = set(effective_keys)
    for key, value in stats.items():
        if key not in effective_key_set:
            continue
        inner_dict = {}
        for key_inner, value_inner in value.items():
            inner_dict[key_inner] = torch.tensor(value_inner)
        stats[key] = inner_dict

    # create policy
    conf = DiffusionConfig(input_shapes=input_shapes, output_shapes=output_shapes, input_normalization_modes=normalization_mode, crop_shape=None, num_inference_steps=5)
    if rgb_encoder_per_camera:
        conf.rgb_encoder_per_camera = True
    policy = DiffusionPolicy(conf, dataset_stats=stats)

    # inference
    observation = {"observation.state": torch.rand(1, 6)}
    for name in camera_names:
        observation[f"observation.image.{name}"] = torch.rand(1, 3, resol, resol)

    action = policy.select_action(observation)

    # compute hash of action
    hash_value = sha256(action.numpy().tobytes()).hexdigest()
    print("hash value of output: ", hash_value)

if __name__ == "__main__":
    # case 1
    print("test with only gripper camera")
    camera_names = ["gripper"]
    run(camera_names, False)

    # case 2
    print("test with gripper and webcam cameras where rgb_encoder_per_camera is not set")
    camera_names = ["gripper", "webcam"]
    run(camera_names, False)

    # case 3
    print("test with gripper and webcam cameras where rgb_encoder_per_camera is set")
    has_rgb_encoder_per_camera = "rgb_encoder_per_camera" in DiffusionConfig.__annotations__
    if not has_rgb_encoder_per_camera:
        print("DiffusionConfig does not have attribute rgb_encoder_per_camera")
    else:
        camera_names = ["gripper", "webcam"]
        run(camera_names, True)
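The hash comparison is meaningful because `torch.manual_seed(0)` is called before the policy is built, so both parameter initialization and the random observation are reproducible. It is also sensitive to the order in which modules are created: an extra module constructed before the existing ones consumes random numbers and changes every weight initialized after it. A minimal sketch of that effect, using plain `nn.Linear` layers rather than the diffusion policy (`weights_hash` and `build` are hypothetical helpers):

```python
import hashlib

import torch
from torch import nn


def weights_hash(module: nn.Module) -> str:
    """SHA-256 over the raw bytes of every parameter, in registration order."""
    h = hashlib.sha256()
    for p in module.parameters():
        h.update(p.detach().cpu().numpy().tobytes())
    return h.hexdigest()


def build(extra_module_first: bool) -> nn.Module:
    torch.manual_seed(0)
    if extra_module_first:
        # Hypothetical extra module created before the "original" layer: its
        # initialization draws from the global RNG and shifts what follows.
        nn.Linear(4, 4)
    return nn.Linear(4, 2)


baseline = weights_hash(build(extra_module_first=False))
print(baseline == weights_hash(build(extra_module_first=False)))  # True: same seed, same order
print(baseline == weights_hash(build(extra_module_first=True)))   # False: RNG stream shifted
```

By the same reasoning, case 3 builds additional encoders and is expected to hash differently from case 2, while identical hashes for cases 1 and 2 are consistent with the default construction path being left untouched.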

Before this PR

integral@umejuice:~/python/snippets/python/ext_examples/lerobot$ python3 diffusion_policy_multiple_cameras.py 
test with only gripper camera
hash value of output:  b3f55809cad72b41e484ec589d9ace5acf1f16439c4b0d71817dc6cbb41d0f4f
test with gripper and webcam cameras where rgb_encoder_per_camera is not set
hash value of output:  d08244f48c83964d90a3224790c9d1f9fc20d58c870420c028b3be46a0ff3b22
test with gripper and webcam cameras where rgb_encoder_per_camera is set
DiffusionConfig does not have attribute rgb_encoder_per_camera

After this PR

integral@umejuice:~/python/snippets/python/ext_examples/lerobot$ python3 diffusion_policy_multiple_cameras.py 
test with only gripper camera
hash value of output:  b3f55809cad72b41e484ec589d9ace5acf1f16439c4b0d71817dc6cbb41d0f4f
test with gripper and webcam cameras where rgb_encoder_per_camera is not set
hash value of output:  d08244f48c83964d90a3224790c9d1f9fc20d58c870420c028b3be46a0ff3b22
test with gripper and webcam cameras where rgb_encoder_per_camera is set
hash value of output:  a636c3ad1313eda2504b008d92eb823d58acc6344582bb4922d73892371e5da0

The hash values for cases 1 and 2 match before and after this PR, so the default behavior (a single shared encoder) is unchanged.
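Rather than comparing the digests by eye, a reviewer could check the two logs mechanically. A small sketch, assuming both outputs are pasted in as strings (copied from the runs above, without the shell prompt); `parse_hashes` is a hypothetical helper, not part of lerobot:

```python
import re
from typing import Dict, Optional

BEFORE_LOG = """\
test with only gripper camera
hash value of output:  b3f55809cad72b41e484ec589d9ace5acf1f16439c4b0d71817dc6cbb41d0f4f
test with gripper and webcam cameras where rgb_encoder_per_camera is not set
hash value of output:  d08244f48c83964d90a3224790c9d1f9fc20d58c870420c028b3be46a0ff3b22
test with gripper and webcam cameras where rgb_encoder_per_camera is set
DiffusionConfig does not have attribute rgb_encoder_per_camera
"""

AFTER_LOG = """\
test with only gripper camera
hash value of output:  b3f55809cad72b41e484ec589d9ace5acf1f16439c4b0d71817dc6cbb41d0f4f
test with gripper and webcam cameras where rgb_encoder_per_camera is not set
hash value of output:  d08244f48c83964d90a3224790c9d1f9fc20d58c870420c028b3be46a0ff3b22
test with gripper and webcam cameras where rgb_encoder_per_camera is set
hash value of output:  a636c3ad1313eda2504b008d92eb823d58acc6344582bb4922d73892371e5da0
"""

HASH_LINE = re.compile(r"hash value of output:\s+([0-9a-f]{64})$")


def parse_hashes(log: str) -> Dict[str, Optional[str]]:
    """Map each 'test with ...' heading to the hash printed under it (None if absent)."""
    results: Dict[str, Optional[str]] = {}
    current = None
    for line in log.splitlines():
        line = line.strip()
        if line.startswith("test with"):
            current = line
            results[current] = None
        elif current is not None:
            match = HASH_LINE.match(line)
            if match:
                results[current] = match.group(1)
    return results


before, after = parse_hashes(BEFORE_LOG), parse_hashes(AFTER_LOG)
# Every case that produced a hash before the PR must produce the same hash after it.
changed = [case for case, h in before.items() if h is not None and after.get(case) != h]
print(changed)  # []
```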

b) validity check

The test above includes case 3, which runs only when DiffusionConfig has the rgb_encoder_per_camera attribute, i.e. only after this PR. As the output shows, it runs without errors and returns an action.
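Case 3 decides whether the option exists by looking it up in `DiffusionConfig.__annotations__`. That works when the field is declared directly on the class, but a class's `__annotations__` lists only its own annotations, not those inherited from a base dataclass, whereas `dataclasses.fields` covers both. A minimal sketch with two hypothetical stand-in classes (`OldConfig`, `NewConfig`), not the real lerobot config:

```python
import dataclasses
from dataclasses import dataclass


@dataclass
class OldConfig:
    """Hypothetical stand-in for a config without the new flag."""
    num_inference_steps: int = 100


@dataclass
class NewConfig(OldConfig):
    """Hypothetical stand-in that adds the flag in a subclass."""
    rgb_encoder_per_camera: bool = False


def has_field(cls, name: str) -> bool:
    # dataclasses.fields() also returns fields inherited from dataclass bases.
    return name in {f.name for f in dataclasses.fields(cls)}


print("rgb_encoder_per_camera" in OldConfig.__annotations__)  # False
print("rgb_encoder_per_camera" in NewConfig.__annotations__)  # True
print("num_inference_steps" in NewConfig.__annotations__)     # False: inherited field is missed
print(has_field(NewConfig, "num_inference_steps"))            # True
```

For a config declared as a single flat dataclass both checks agree; the difference only matters if a flag ends up on a base class.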

c) tested on real robot

I trained the diffusion policy with rgb_encoder_per_camera=True and checked that the policy worked well. I don't have concrete numbers, but the performance seems slightly better than with the original shared encoder.

How to checkout & try? (for the reviewer)

Run the script above and check the hash values before and after this PR. Before running it, save the following YAML file as 'stats-diffusion.yaml'.

action:
  max:
  - 119.3543701171875
  - 49.64094543457031
  - 111.9273910522461
  - 7.118511199951172
  - 106.76649475097656
  - 26.715518951416016
  mean:
  - 94.45906829833984
  - 18.427291870117188
  - 73.95457458496094
  - 1.212478518486023
  - 79.41802978515625
  - 6.242947578430176
  min:
  - 77.6657943725586
  - -29.830978393554688
  - 39.58110809326172
  - -8.624532699584961
  - 44.424076080322266
  - -18.17777442932129
  std:
  - 8.80908489227295
  - 16.74974822998047
  - 14.223931312561035
  - 1.9911779165267944
  - 10.933554649353027
  - 9.21682357788086
episode_index:
  max:
  - 69.0
  mean:
  - 34.674591064453125
  min:
  - 0.0
  std:
  - 20.254430770874023
frame_index:
  max:
  - 176.0
  mean:
  - 61.95341873168945
  min:
  - 0.0
  std:
  - 37.45309829711914
index:
  max:
  - 8585.0
  mean:
  - 4292.5
  min:
  - 0.0
  std:
  - 2478.564453125
next.done:
  max:
  - 1.0
  mean:
  - 0.008152804337441921
  min:
  - 0.0
  std:
  - 0.08992405980825424
observation.image.gripper:
  max:
  - - - 1.0
  - - - 1.0
  - - - 1.0
  mean:
  - - - 0.3928315341472626
  - - - 0.39230185747146606
  - - - 0.360833078622818
  min:
  - - - 0.0
  - - - 0.0
  - - - 0.0
  std:
  - - - 0.2355380803346634
  - - - 0.23801472783088684
  - - - 0.2544608414173126
observation.image.webcam:
  max:
  - - - 1.0
  - - - 1.0
  - - - 1.0
  mean:
  - - - 0.405765175819397
  - - - 0.6512684226036072
  - - - 0.6204251646995544
  min:
  - - - 0.0
  - - - 0.0
  - - - 0.0
  std:
  - - - 0.1711459755897522
  - - - 0.1499525010585785
  - - - 0.1772834062576294
observation.state:
  max:
  - 113.90892791748047
  - 49.61915588378906
  - 111.92845916748047
  - 7.113752841949463
  - 106.78263092041016
  - 26.70062255859375
  mean:
  - 95.00767517089844
  - 15.08882999420166
  - 76.28932189941406
  - 1.3102922439575195
  - 80.53984832763672
  - 6.705284595489502
  min:
  - 77.66748809814453
  - -29.83199119567871
  - 39.72526550292969
  - -5.572042942047119
  - 44.40793991088867
  - -18.174049377441406
  std:
  - 8.437381744384766
  - 18.46141242980957
  - 14.492023468017578
  - 1.9003233909606934
  - 10.978218078613281
  - 8.922645568847656
timestamp:
  max:
  - 17.600000381469727
  mean:
  - 6.195342540740967
  min:
  - 0.0
  std:
  - 3.745309591293335
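In the stats file above, the nested `- - - ` lists for the image keys load as arrays of shape (3, 1, 1): one statistic per color channel, which broadcasts over a (3, H, W) image. The state and action entries are 6-vectors. As a rough sketch of how the two normalization modes used in the script could consume these stats, the functions below assume a lerobot-style convention (`mean_std`: subtract the mean and divide by the std; `min_max`: rescale to [-1, 1]). They are illustrative and not copied from lerobot's normalization code.

```python
import torch

EPS = 1e-8  # guards against division by zero for constant features


def normalize_mean_std(x, mean, std):
    return (x - mean) / (std + EPS)


def normalize_min_max(x, min_, max_):
    # Map [min, max] to [0, 1], then to [-1, 1].
    return (x - min_) / (max_ - min_ + EPS) * 2 - 1


# Per-channel observation.image.gripper stats from stats-diffusion.yaml, shape (3, 1, 1).
img_mean = torch.tensor([[[0.3928315341472626]], [[0.39230185747146606]], [[0.360833078622818]]])
img_std = torch.tensor([[[0.2355380803346634]], [[0.23801472783088684]], [[0.2544608414173126]]])

image = torch.rand(3, 56, 56)
print(normalize_mean_std(image, img_mean, img_std).shape)  # torch.Size([3, 56, 56])

# observation.state min/max from the same file (6 values).
state_min = torch.tensor([77.66748809814453, -29.83199119567871, 39.72526550292969,
                          -5.572042942047119, 44.40793991088867, -18.174049377441406])
state_max = torch.tensor([113.90892791748047, 49.61915588378906, 111.92845916748047,
                          7.113752841949463, 106.78263092041016, 26.70062255859375])
print(normalize_min_max(state_min, state_min, state_max))  # all -1
print(normalize_min_max(state_max, state_min, state_max))  # all close to 1
```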
HiroIshida commented 1 month ago

@alexander-soare Could you review this PR? :)

HiroIshida commented 4 weeks ago

@alexander-soare Thank you for your review! I fixed the name and applied your suggestions. I also modified the test code and checked that it works: https://gist.github.com/HiroIshida/5d0abbec0fdd0d140c5262eb1301c98e

danaaubakirova commented 4 weeks ago

@HiroIshida Thanks a lot for this PR! The quality check is not passing. Otherwise, LGTM. I will approve once that is fixed :)

HiroIshida commented 3 weeks ago

@danaaubakirova Thanks. I applied the formatter.