bytedance / GR-1

Code for "Unleashing Large-Scale Video Generative Pre-training for Visual Robot Manipulation"
Apache License 2.0
190 stars · 9 forks

Put my reproduction experience here & GR-1 reproduction WeChat group #4

Open StarCycle opened 7 months ago

StarCycle commented 7 months ago

This is the GR-1 reproduction WeChat group. If the authors join, I will hand over group-owner permissions to them. Reproduction progress will be updated in this thread, and I will try to keep the QR code up to date.

If you find this repo as amazing as I do, please star it before you leave.

[image: WeChat group QR code]

StarCycle commented 7 months ago

[2024.4.19] Successfully run the policy

[image: terminal output of the policy running]

Recommendations:

  1. Use Python 3.8. Since calvin is installed in a conda environment with Python 3.8, install the GR-1 codebase directly into that same conda environment.
  2. If you encounter problems installing pyhash (while installing calvin), run:
    pip install setuptools==57.5.0
  3. Generally speaking, you can install calvin with the following commands:
    source activate
    conda create -n calvin_venv python=3.8
    conda activate calvin_venv
    pip install setuptools==57.5.0
    git clone --recurse-submodules https://github.com/mees/calvin.git
    export CALVIN_ROOT=$(pwd)/calvin
    cd calvin
    cd calvin_env; git checkout main
    cd ..
    sh ./install.sh; cd ..

    Remember to use the latest calvin_env module, which fixes a bug in the turn_off_led task. See this for details.

  4. You don't have to download the original CALVIN dataset to evaluate the policy. Downloading and unzipping fake_dataset.zip as a fake dataset is sufficient.
  5. After executing sh install.sh in the GR-1 folder, if you get errors reported from GPT2, note that this repo relies on transformers==4.5.1, which installs easily under Python 3.8 (a quick post-install sanity check is sketched right after this list):
    pip install transformers==4.5.1
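
As a quick post-install sanity check (this snippet is my own suggestion, not from the repo; the import names assume a standard calvin install):

import transformers
print(transformers.__version__)  # expect 4.5.1

import pyhash       # see item 2 if this failed to install
import calvin_env   # installed by calvin's install.sh
print('installation looks good')
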
StarCycle commented 7 months ago

[2024.4.20] Evaluation result of ABC->D

Results for Epoch None:
Average successful sequence length: 3.256
Success rates for i instructions in a row:
1: 88.0%
2: 75.2%
3: 63.9%
4: 54.5%
5: 44.0%
rotate_blue_block_right: 57 / 68 |  SR: 83.8%
move_slider_right: 229 / 238 |  SR: 96.2%
lift_red_block_slider: 107 / 111 |  SR: 96.4%
place_in_slider: 261 / 301 |  SR: 86.7%
turn_off_led: 159 / 162 |  SR: 98.1%
push_into_drawer: 78 / 100 |  SR: 78.0%
lift_blue_block_drawer: 14 / 14 |  SR: 100.0%
close_drawer: 166 / 166 |  SR: 100.0%
lift_pink_block_slider: 106 / 114 |  SR: 93.0%
open_drawer: 301 / 301 |  SR: 100.0%
rotate_red_block_right: 60 / 66 |  SR: 90.9%
lift_red_block_table: 138 / 146 |  SR: 94.5%
lift_pink_block_table: 120 / 136 |  SR: 88.2%
turn_on_led: 167 / 173 |  SR: 96.5%
push_red_block_left: 59 / 73 |  SR: 80.8%
lift_blue_block_table: 125 / 129 |  SR: 96.9%
rotate_blue_block_left: 53 / 53 |  SR: 100.0%
place_in_drawer: 136 / 141 |  SR: 96.5%
move_slider_left: 202 / 204 |  SR: 99.0%
rotate_red_block_left: 57 / 58 |  SR: 98.3%
stack_block: 114 / 153 |  SR: 74.5%
push_pink_block_left: 61 / 65 |  SR: 93.8%
lift_blue_block_slider: 106 / 111 |  SR: 95.5%
rotate_pink_block_right: 57 / 59 |  SR: 96.6%
unstack_block: 53 / 53 |  SR: 100.0%
push_red_block_right: 31 / 61 |  SR: 50.8%
rotate_pink_block_left: 44 / 45 |  SR: 97.8%
push_blue_block_left: 49 / 57 |  SR: 86.0%
lift_pink_block_drawer: 8 / 9 |  SR: 88.9%
push_pink_block_right: 34 / 57 |  SR: 59.6%
push_blue_block_right: 27 / 59 |  SR: 45.8%
turn_on_lightbulb: 43 / 172 |  SR: 25.0%
turn_off_lightbulb: 18 / 145 |  SR: 12.4%
lift_red_block_drawer: 16 / 16 |  SR: 100.0%

Best model: epoch None with average sequences length of 3.256
disconnecting id 0 from server
Destroy EGL OpenGL window.

Note that the performance reported in their original paper is: [image: results table from the paper]

The success rate in my evaluation is even higher! Perhaps because I fixed the turn_off_led bug, while @bdrhtw @hongtaowu67 did not notice this problem?

On the ABC->D leaderboard, it's very close to the current SOTA, 3D Diffuser Actor, without using depth information: [image: leaderboard screenshot]

The log files: result.txt success_rate.txt

StarCycle commented 7 months ago

[2024.4.20] Evaluation result of ABCD->D

Results for Epoch None:
Average successful sequence length: 4.302
Success rates for i instructions in a row:
1: 95.2%
2: 91.1%
3: 86.5%
4: 81.7%
5: 75.7%
rotate_blue_block_right: 72 / 77 |  SR: 93.5%
move_slider_right: 274 / 274 |  SR: 100.0%
lift_red_block_slider: 126 / 137 |  SR: 92.0%
place_in_slider: 338 / 357 |  SR: 94.7%
turn_off_lightbulb: 150 / 150 |  SR: 100.0%
turn_off_led: 168 / 168 |  SR: 100.0%
push_into_drawer: 100 / 125 |  SR: 80.0%
lift_blue_block_drawer: 19 / 19 |  SR: 100.0%
close_drawer: 218 / 218 |  SR: 100.0%
lift_pink_block_slider: 140 / 141 |  SR: 99.3%
open_drawer: 359 / 360 |  SR: 99.7%
rotate_red_block_right: 74 / 75 |  SR: 98.7%
lift_red_block_table: 175 / 177 |  SR: 98.9%
lift_pink_block_table: 160 / 170 |  SR: 94.1%
move_slider_left: 252 / 252 |  SR: 100.0%
turn_on_lightbulb: 182 / 182 |  SR: 100.0%
rotate_blue_block_left: 64 / 68 |  SR: 94.1%
push_blue_block_left: 58 / 69 |  SR: 84.1%
turn_on_led: 176 / 179 |  SR: 98.3%
stack_block: 170 / 201 |  SR: 84.6%
push_pink_block_right: 47 / 67 |  SR: 70.1%
push_red_block_left: 65 / 78 |  SR: 83.3%
lift_blue_block_table: 173 / 178 |  SR: 97.2%
place_in_drawer: 173 / 175 |  SR: 98.9%
rotate_red_block_left: 62 / 65 |  SR: 95.4%
push_pink_block_left: 65 / 77 |  SR: 84.4%
lift_blue_block_slider: 123 / 132 |  SR: 93.2%
push_red_block_right: 44 / 72 |  SR: 61.1%
lift_pink_block_drawer: 15 / 15 |  SR: 100.0%
rotate_pink_block_right: 70 / 71 |  SR: 98.6%
unstack_block: 67 / 69 |  SR: 97.1%
push_blue_block_right: 51 / 72 |  SR: 70.8%
rotate_pink_block_left: 54 / 57 |  SR: 94.7%
lift_red_block_drawer: 18 / 18 |  SR: 100.0%

Best model: epoch None with average sequences length of 4.302
disconnecting id 0 from server
Destroy EGL OpenGL window.

Again, the performance is slightly higher than the original paper reported.

It's the SOTA on the ABCD->D leaderboard, better than the much heavier RoboFlamingo. [image: leaderboard screenshot]

The log files: result(1).txt success_rate(1).txt

StarCycle commented 7 months ago

[2024.4.20] Network Details

These hyperparameters are set in this config file.

[image: screenshot of the config file]

Trainable parameters: 45,988,039
Total parameters: 283,139,272 (since the CLIP vision encoder is loaded but never used, the effective parameter count is 195,290,056)

Parameter count of the MAE vision encoder (frozen): 85,798,656 (ViT-Base)
Parameter count of the CLIP text encoder (frozen): 37,828,608 (from a ViT-B/32 CLIP model)

Parameter count of the perceiver resampler: 18,897,408
Resampler architecture: a transformer with 2 layers and cross-attention:

(perceiver_resampler): PerceiverResampler(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PerceiverAttention(
          (norm_media): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (norm_latents): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (to_q): Linear(in_features=768, out_features=512, bias=False)
          (to_kv): Linear(in_features=768, out_features=1024, bias=False)
          (to_out): Linear(in_features=512, out_features=768, bias=False)
        )
        (1): ...

Parameter count of the core transformer (GPT2): 21,294,720
GPT2 architecture: 12 layers, hidden size = 384:

(transformer): GPT2Model(
    (wte): Embedding(1, 384)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
     (1~11): ...
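
For reference, here is a rough sketch of how a GPT-2 trunk with this shape could be configured in transformers 4.5.1; vocab_size=1 and n_embd=384 follow the printout above, while n_head and n_positions are my assumptions, not values taken from the repo config:

from transformers import GPT2Config, GPT2Model

config = GPT2Config(
    vocab_size=1,      # matches Embedding(1, 384): token ids are unused, inputs are projected features
    n_embd=384,        # hidden size seen in the printout
    n_layer=12,
    n_head=12,         # assumption (384 / 12 = 32-dim heads)
    n_positions=1024,  # assumption; only needs to exceed the 430-token input
)
gpt2 = GPT2Model(config)
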

Parameter count of the image decoder: 3,548,928 (transformer) + 147,840 (embedding)
Decoder architecture: a 2-layer transformer:

(decoder_embed): Linear(in_features=384, out_features=384, bias=True)
  (decoder_blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): ...

Parameters of other layers like the action projection layer are not discussed here.

Image preprocess

input_size = (224, 224)
rgb_mean = (0.485, 0.456, 0.406)  # ImageNet statistics
rgb_std = (0.229, 0.224, 0.225)
self.preprocess = T.Compose([
    T.Resize(input_size, interpolation=Image.BICUBIC),
    T.Normalize(rgb_mean, rgb_std)])  # no ToTensor: the input is expected to already be a tensor
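
A minimal usage sketch (my own example, not repo code), assuming the input frame is an HxWx3 uint8 array that is first converted to a float tensor in [0, 1]:

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

preprocess = T.Compose([
    T.Resize((224, 224), interpolation=Image.BICUBIC),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

frame = np.zeros((200, 200, 3), dtype=np.uint8)                 # placeholder CALVIN frame
rgb = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0  # (3, H, W), values in [0, 1]
rgb = preprocess(rgb)                                           # (3, 224, 224), normalized
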

As for random shift of images, please refer to #5.

As for skip-frame in video & action prediction, please refer to #6.

For the target image, they do:

p = self.patch_size
h_p = h // p
w_p = w // p
rgb = rgb.reshape(shape=(batch_size, sequence_length, 3, h_p, p, w_p, p)) 
obs_targets = rgb.permute(0, 1, 3, 5, 4, 6, 2)
obs_targets = obs_targets.reshape(shape=(batch_size, sequence_length, h_p * w_p, (p**2) * 3))  # (b, l, n_patches, p*p*3)
if not self.without_norm_pixel_loss:
    # norm the target 
    obs_targets = (obs_targets - obs_targets.mean(dim=-1, keepdim=True)
       ) / (obs_targets.var(dim=-1, unbiased=True, keepdim=True).sqrt() + 1e-6)
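
For the shapes used here (224x224 input and 16x16 patches, as stated later in this thread), the patch arithmetic works out to:

h = w = 224
p = 16                      # patch_size
h_p, w_p = h // p, w // p   # 14, 14
n_patches = h_p * w_p       # 196 patches per frame
patch_dim = p * p * 3       # 768 values per patch
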

Language preprocess

# Embed language
lang_embeddings = self.model_clip.encode_text(language)
lang_embeddings = lang_embeddings / (lang_embeddings.norm(dim=1, keepdim=True) + 1e-6) # normalization 
lang_embeddings = self.embed_lang(lang_embeddings.float())  # (b, h)
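
A minimal sketch of how `language` can be produced with the OpenAI CLIP package (the exact model variant and loading code used by the repo may differ; ViT-B/32 here is an assumption):

import clip
import torch

model_clip, _ = clip.load("ViT-B/32", device="cpu")
language = clip.tokenize(["push the blue block to the right"])  # (1, 77) token ids
with torch.no_grad():
    lang_embeddings = model_clip.encode_text(language)          # (1, 512)
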
StarCycle commented 7 months ago

[2024.4.20] APIs

The GR1 policy class, whose hyperparameters are set in this config file:

class GR1(nn.Module):
    def __init__(
            self,
            model_clip, # nn.module
            model_mae, # nn.module
            state_dim, # config['state_dim']
            act_dim, # config['act_dim']
            hidden_size, # config['embed_dim']
            sequence_length, # config['seq_len']
            training_target, # List: ['act_pred', 'fwd_pred', 'fwd_pred_hand'], remove some of them if needed
            img_feat_dim, # config['img_feat_dim']
            patch_feat_dim, # config['patch_feat_dim']
            lang_feat_dim, # config['lang_feat_dim']
            resampler_params, # Dict from config: {'depth': x, 'dim_head': x, 'heads': x, 'num_latents': x, 'num_media_embeds': x}
            without_norm_pixel_loss=False, # whether normalize the target image or not
            use_hand_rgb=True, # whether use hand camera image or not
            **kwargs
    ):
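
A hypothetical instantiation sketch that simply mirrors the signature above; model_clip, model_mae, and config are assumed to be loaded beforehand, and the resampler values marked below are illustrative, not the repo's actual config:

policy = GR1(
    model_clip=model_clip,   # frozen CLIP model, loaded beforehand
    model_mae=model_mae,     # frozen MAE ViT, loaded beforehand
    state_dim=config['state_dim'],
    act_dim=config['act_dim'],
    hidden_size=config['embed_dim'],
    sequence_length=config['seq_len'],
    training_target=['act_pred', 'fwd_pred', 'fwd_pred_hand'],
    img_feat_dim=config['img_feat_dim'],
    patch_feat_dim=config['patch_feat_dim'],
    lang_feat_dim=config['lang_feat_dim'],
    resampler_params={
        'depth': 2,             # 2-layer resampler, per the printout above
        'dim_head': 64,         # illustrative: 8 heads x 64 = the 512-dim to_q projection
        'heads': 8,             # illustrative
        'num_latents': 9,       # 9 resampled patch tokens, per the shapes below
        'num_media_embeds': 1,  # illustrative
    },
    without_norm_pixel_loss=False,
    use_hand_rgb=True,
)
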

In calvin_evaluation.py:

with torch.no_grad():
    prediction = self.policy(
        rgb=rgb_data,                    # [1, 10, 3, 224, 224], rgb_data[:, 1:]=0 in evaluation
        hand_rgb=hand_rgb_data,          # [1, 10, 3, 224, 224], hand_rgb_data[:, 1:]=0 in evaluation
        state=state_data,                # state_data['arm']: [1, 10, 6], state_data['gripper']: [1, 10, 2]
        language=tokenized_text,         # [1, 77]
        attention_mask=attention_mask,   # [1, 10], it's [[1,0,0,0,0,0,0,0,0,0]] in evaluation (0: this part of input is ignored)
   )
   '''
   In the output:
   prediction['obs_preds']:            [1, 10, 196, 768] if 'fwd_pred' in training_target, otherwise None
   prediction['obs_targets']:          [1, 10, 196, 768] if 'fwd_pred' in training_target, otherwise None
   prediction['obs_hand_preds']:       [1, 10, 196, 768] if fwd_pred_hand in training_target and use_hand_rgb=True, otherwise None
   prediction['obs_hand_targets']:     [1, 10, 196, 768] if fwd_pred_hand in training_target and use_hand_rgb=True, otherwise None
   prediction['arm_action_preds']:     [1, 10, 6]
   prediction['gripper_action_preds']: [1, 10, 1]
   '''

prediction['obs_targets'] and prediction['obs_hand_targets'] do not require gradients. In the shapes above, 196 is the number of patches and 768 is patch_size*patch_size*3 (16*16*3). You can recover the predicted video frames with:

import torch
from einops import rearrange

# if without_norm_pixel_loss=False, an additional per-patch unnormalization is also needed
prediction['obs_preds'] = rearrange(prediction['obs_preds'], 'b s (hp wp) (p1 p2 c) -> b s c (hp p1) (wp p2)', hp=14, wp=14, p1=16, p2=16, c=3)
prediction['obs_preds'] *= torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 1, -1, 1, 1)  # undo ImageNet std
prediction['obs_preds'] += torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 1, -1, 1, 1)  # undo ImageNet mean
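
A small follow-up sketch (my own, not from the repo) to turn the de-normalized predictions into viewable uint8 frames:

from PIL import Image

frames = prediction['obs_preds'][0].clamp(0, 1)                   # (seq_len, 3, 224, 224)
frames = (frames * 255).byte().permute(0, 2, 3, 1).cpu().numpy()  # (seq_len, 224, 224, 3) uint8
for i, frame in enumerate(frames):
    Image.fromarray(frame).save(f'pred_frame_{i:02d}.png')
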
StarCycle commented 7 months ago

[2024.4.20] Some internal tensor shapes

obs_embeddings from MAE: (batch_size * seq_len, 768)
patch_embeddings from MAE: (batch_size * seq_len, 196, 768)
patch_embeddings from the resampler: (batch_size * seq_len, 1, 9, 768) (heavily compresses the data)
lang_embeddings from the CLIP text encoder: (batch_size * seq_len, 512)

Their last dimensions are projected to 384 before being sent into GPT2. GPT2 input shape: (batch size, 430, 384).

After the lightweight image decoder, the last dimension of obs_preds is projected from 384 to 768.
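
The 430 comes from 10 timesteps x 43 tokens per timestep. The per-timestep breakdown below is my own reconstruction from the shapes above and the GR-1 paper, so treat it as an educated guess rather than repo ground truth:

tokens_per_timestep = (
    1        # language embedding
    + 1      # proprioceptive state embedding
    + 1 + 9  # static-camera global embedding + 9 resampled patch tokens
    + 1 + 9  # hand-camera global embedding + 9 resampled patch tokens
    + 10     # learnable [OBS] query tokens (static camera)
    + 10     # learnable [OBS] query tokens (hand camera)
    + 1      # learnable [ACT] query token
)
assert tokens_per_timestep * 10 == 430
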

StarCycle commented 7 months ago

[2024.4.21] Video Generation Test

Theoretically, GR-1 can generate a short video in a single forward pass because of the learned [OBS] tokens. It can also predict a short trajectory (or, if you prefer, an action chunk) in a single pass.

The reconstruction quality is poor, but the following video still looks like the desk from the CALVIN benchmark: [video: generated frames]

The original desk: [image]

My code may be wrong; I hope @bdrhtw @hongtaowu67 will release the official video generation script.

bdrhtw commented 7 months ago

Hi @StarCycle ,

thanks for your attention to this work.

We are not aware of the bug in the turn_off_led task. Can you kindly share more details on this bug?

As for video prediction, we follow MAE and normalize the target image patches in training: https://github.com/bytedance/GR-1/blob/main/models/gr1.py#L193. To generate normal images, without_norm_pixel_loss needs to be set to True in training. The training loss for video prediction is an L2 loss, so the reconstructed images can be blurry. Nevertheless, it is still a strong regularization that helps the network learn actions more effectively.

StarCycle commented 7 months ago

Hi @bdrhtw,

I found the turn_off_led bug in the README of 3D Diffuser Actor. They shared a link to this post. Since GR-1 predates their work, I guess you have the same bug in your CALVIN environment. [image: screenshot of the bug description]

For video generation, I see that your evaluation code uses without_norm_pixel_loss=False (in the config file). If I understand correctly:

  1. You set without_norm_pixel_loss=False to train the current checkpoint. The reconstructed images do not look good in this case, but it helps the model learn the action output more effectively.
  2. You also trained another checkpoint with without_norm_pixel_loss=True. Its reconstructed images look better, but they are still blurry due to the L2 loss.

By the way, it would be very helpful if you could join the WeChat group :)

bdrhtw commented 7 months ago

Thanks for sharing the bug fix.

Yes, we trained two checkpoints, i.e., one with without_norm_pixel_loss=False and another with it set to True.

StarCycle commented 6 months ago

[2024.5.17] For the full training code, please refer to #8

snitchyang commented 4 months ago

Could you update the group QR code?

StarCycle commented 4 months ago

@snitchyang Hi, it has been updated. I will also update the code in my repo shortly.

negativegluon commented 2 months ago

Could you please update the group QR code? QwQ

StarCycle commented 2 months ago

@negativegluon Please use this one: [image: QR code]

wlxing1901 commented 1 month ago

@StarCycle Could you please update the group QR code again? Thanks!

StarCycle commented 1 month ago

@wlxing1901 Sure, it has been updated. Welcome to join the group: [image: QR code]

SpaceLearner commented 1 month ago

Please add me to the group.

StarCycle commented 1 month ago

@SpaceLearner Please use this QR code: [image: QR code]

AoqunJin commented 1 month ago

Please add me to the group, @StarCycle 🤗

StarCycle commented 1 month ago

@AoqunJin

[image: QR code]

Yes, I have updated it!

Jacksonfei commented 1 month ago

@StarCycle In your terminal output (https://github.com/bytedance/GR-1/issues/4#issuecomment-2066639543), it shows that you have 3 EGL devices to choose from. However, my output shows "Couldn't find correct EGL device. Setting EGL_VISIBLE_DEVICE=0. When using DDP with many GPUs this can lead to OOM errors. Did you install PyBullet correctly? Please refer to calvin env README". It seems that I can only use one EGL device, even though I actually have 8 GPUs. How can I solve this problem? Thanks a lot!!!

My output is:

Global seed set to 0
loading state dict: logs/snapshot_ABC.pt...
pybullet build time: Nov 28 2023 23:52:03
Couldn't find correct EGL device. Setting EGL_VISIBLE_DEVICE=0. When using DDP with many GPUs this can lead to OOM errors. Did you install PyBullet correctly? Please refer to calvin env README
argv[0]=--width=200
argv[1]=--height=200
EGL device choice: 0 of 1 (from EGL_VISIBLE_DEVICES)
Loaded EGL 1.5 after reload.
GL_VENDOR=Mesa
GL_RENDERER=llvmpipe (LLVM 15.0.7, 256 bits)
GL_VERSION=3.3 (Core Profile) Mesa 23.2.1-1ubuntu3.1~22.04.2

StarCycle commented 1 month ago

Hello @Jacksonfei,

The EGL problem comes from calvin, and whether it appears depends on the machine you use.

One option is to run conda install -c conda-forge gcc=12.1, but it does not always work.

Would you like to add me on WeChat (StarRingSpace)? In the worst case, I can try to set up a Docker image for you: you would train on your machine and evaluate inside the Docker container. In my experience, you will eventually solve the EGL problem by switching to a different Docker image or by carefully configuring your system.
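
One way to check how many EGL devices your machine actually exposes is the egl_probe package (pip install egl_probe); this is my own suggestion, not something from the CALVIN docs:

import egl_probe

devices = egl_probe.get_available_devices()
print(devices)  # a single entry here would explain the EGL_VISIBLE_DEVICE=0 fallback
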

StarCycle commented 1 month ago

I also list many installation errors and solutions here: https://github.com/EDiRobotics/mimictest#possible-installation-problems

Tigerdwgth commented 2 weeks ago

You're a legend! Could you share a new QR code?

StarCycle commented 2 weeks ago

@Tigerdwgth Hi, please check this QR code: [image: QR code]

1786707378 commented 3 days ago

Please add me to the group, @StarCycle 🤗

StarCycle commented 3 days ago

@1786707378 Please use this QR code: [image: QR code]