bytedance / GR-1

Code for "Unleashing Large-Scale Video Generative Pre-training for Visual Robot Manipulation"
Apache License 2.0

How to predict relative actions on CALVIN when ∆t = 3? #6

Closed. StarCycle closed this issue 4 months ago.

StarCycle commented 4 months ago

Hi @bdrhtw @hongtaowu67,

In appendix A.4 you mentioned:

We compare the effectiveness of predicting images at different future steps (i.e. 1, 3, and 5) on CALVIN benchmark. Results are shown in Tab. 5. Pre-training is not used in this ablation. It is observed that increasing the step from 1 to 3 improves success rates. This may be because consecutive frames are very similar and predicting frames that are farther away from the current step helps the robot to understand more about the future. But the improvement saturates soon. We hypothesize that this is because the model is trained to predict local actions and predicting frames that are too far into the future may not be able to provide good guidance for immediate local action prediction.

And in appendix A.1:

Robot data is denser compared to the sampled video frame sequences in pre-training. Therefore, when finetuning on robot data, we set ∆t = 3. And we train the network to predict captured images from the static camera and the gripper camera. The input sequence length is set as 10.

The input and output shapes of the network are:

with torch.no_grad():
    prediction = self.policy(
        rgb=rgb_data,                    # [1, 10, 3, 224, 224], rgb_data[:, 1:]=0 in evaluation
        hand_rgb=hand_rgb_data,          # [1, 10, 3, 224, 224], hand_rgb_data[:, 1:]=0 in evaluation
        state=state_data,                # state_data['arm']: [1, 10, 6], state_data['gripper']: [1, 10, 2]
        language=tokenized_text,         # [1, 77]
        attention_mask=attention_mask,   # [1, 10], it's [[1,0,0,0,0,0,0,0,0,0]] in evaluation (0: this part of input is ignored)
    )
    '''
    In the output:
    prediction['obs_preds']:            [1, 10, 196, 768] if 'fwd_pred' in training_target, otherwise None
    prediction['obs_targets']:          [1, 10, 196, 768] if 'fwd_pred' in training_target, otherwise None
    prediction['obs_hand_preds']:       [1, 10, 196, 768] if 'fwd_pred_hand' in training_target and use_hand_rgb=True, otherwise None
    prediction['obs_hand_targets']:     [1, 10, 196, 768] if 'fwd_pred_hand' in training_target and use_hand_rgb=True, otherwise None
    prediction['arm_action_preds']:     [1, 10, 6]
    prediction['gripper_action_preds']: [1, 10, 1]
    '''
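
For reference, this is how I sanity-check those shapes with dummy inputs. It is only a sketch on my side: policy is assumed to be an already constructed GR-1 model, and all the tensors are just zeros.

import torch

B, T = 1, 10  # batch size and input sequence length used on CALVIN

rgb_data = torch.zeros(B, T, 3, 224, 224)        # static-camera frames
hand_rgb_data = torch.zeros(B, T, 3, 224, 224)   # gripper-camera frames
state_data = {
    'arm': torch.zeros(B, T, 6),       # arm state
    'gripper': torch.zeros(B, T, 2),   # gripper state
}
tokenized_text = torch.zeros(B, 77, dtype=torch.long)  # tokenized language instruction
attention_mask = torch.zeros(B, T)
attention_mask[:, 0] = 1  # evaluation: only the current frame is observed

with torch.no_grad():
    prediction = policy(  # assumption: an already constructed GR-1 model instance
        rgb=rgb_data,
        hand_rgb=hand_rgb_data,
        state=state_data,
        language=tokenized_text,
        attention_mask=attention_mask,
    )
print(prediction['arm_action_preds'].shape)      # expected: torch.Size([1, 10, 6])
print(prediction['gripper_action_preds'].shape)  # expected: torch.Size([1, 10, 1])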

If I understand correctly, when finetuning on the CALVIN dataset, the interval between rgb_data[0, 0] and rgb_data[0, 1] is ∆t = 3, and the interval between prediction['obs_preds'][0, 0] and prediction['obs_preds'][0, 1] is also ∆t = 3.

Since you use a relative end-effector action space, is prediction['arm_action_preds'] then actually the sum of 3 consecutive relative actions?

For example, would the label be something like the sketch below?
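
A purely illustrative sketch of that reading (rel_actions is just a hypothetical name for the per-step relative ee action labels, not something from the repo):

import torch

# Hypothetical per-step relative arm action labels a_0 ... a_9 (6-DoF each).
rel_actions = torch.randn(10, 6)

t = 0
# If arm_action_preds at step t were the sum of 3 consecutive relative actions,
# its target would be a_t + a_{t+1} + a_{t+2}:
summed_target = rel_actions[t:t + 3].sum(dim=0)
print(summed_target.shape)  # torch.Size([6])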

bdrhtw commented 4 months ago

Hi @StarCycle, the target for the rgb prediction is three timesteps in the future, but the interval between two timesteps in an input sequence is one.

StarCycle commented 4 months ago

Hi @bdrhtw,

If I understand correctly, the loss function will be

rgb_mask = batch['mask'].unsqueeze(-1).unsqueeze(-1)[:, 3:]
loss['rgb_static'] = (F.mse_loss(pred['obs_preds'][:, :-3], pred['obs_targets'][:, 3:], reduction='none')*rgb_mask).mean()
loss['rgb_gripper'] = (F.mse_loss(pred['obs_hand_preds'][:, :-3], pred['obs_hand_targets'][:, 3:], reduction='none')*rgb_mask).mean()

So you use the first 7 items of pred['obs_preds'] to predict the last 7 items of pred['obs_targets'], which is how

The target for the rgb prediction is three timesteps in the future. But the interval between two timesteps in an input sequence is one.

is achieved.
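
As a quick index check on my side (not code from the repo), the pairing for a sequence length of 10 and a shift of 3 would be:

seq_len, shift = 10, 3
pairs = [(i, i + shift) for i in range(seq_len - shift)]
print(pairs)       # [(0, 3), (1, 4), (2, 5), (3, 6), (4, 7), (5, 8), (6, 9)]
print(len(pairs))  # 7 prediction/target pairs, each 3 environment steps apart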

Is my understanding correct?

Best, Star Cycle

bdrhtw commented 4 months ago

Yes, it's correct.

StarCycle commented 4 months ago

Thanks!