bytedance / GR-1

Code for "Unleashing Large-Scale Video Generative Pre-training for Visual Robot Manipulation"
Apache License 2.0

Coefficients in the loss function and random shifting? #7

Closed: StarCycle closed this 3 months ago

StarCycle commented 4 months ago

Hi @bdrhtw @hongtaowu67,

In the paper, it seems that the coefficient for every term is 1: [image of the loss equation from the paper]

But when I train the policy, it seems that using a higher coefficient on L_arm improves performance.

My loss function looks like this. If it's wrong please let me know...

import math
import torch.nn.functional as F

def masked_loss(pred, target, mask, skip_frame=0, loss_func=F.mse_loss):
    # Optionally shift predictions vs. targets by skip_frame steps along the time axis
    if skip_frame == 0:
        new_pred = pred
    else:
        new_pred = pred[:, :-skip_frame]
    new_target = target[:, skip_frame:]
    new_mask = mask[:, skip_frame:]
    data_shape, mask_shape = new_pred.shape, new_mask.shape
    loss = loss_func(new_pred, new_target, reduction='none')
    # Broadcast the [B, T] mask over the trailing feature dimensions
    for _ in range(len(data_shape) - len(mask_shape)):
        new_mask = new_mask.unsqueeze(-1)
    # Mean over the unmasked elements
    loss = (loss * new_mask).sum() / new_mask.sum() / math.prod(data_shape[len(mask_shape):])
    return loss

loss = {}
loss['rgb_static'] = masked_loss(pred['obs_preds'], pred['obs_targets'], batch['mask'], cfg['skip_frame'], F.mse_loss)
loss['rgb_gripper'] = masked_loss(pred['obs_hand_preds'], pred['obs_hand_targets'], batch['mask'], cfg['skip_frame'], F.mse_loss)
loss['action_arm'] = masked_loss(pred['arm_action_preds'], batch['actions'][:, :, :6], batch['mask'], 0, F.smooth_l1_loss)
loss['action_gripper'] = masked_loss(torch.sigmoid(pred['gripper_action_preds']), batch['actions'][:, :, -1:], batch['mask'], 0, F.binary_cross_entropy)
total_loss = loss['rgb_static'] + loss['rgb_gripper'] + 100 * loss['action_arm'] + loss['action_gripper']
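
As a quick sanity check (my own toy example, not from the repo), `masked_loss` reduces to a plain mean over the unmasked elements:

```python
import math
import torch
import torch.nn.functional as F

# (same masked_loss as above, repeated so this snippet runs on its own)
def masked_loss(pred, target, mask, skip_frame=0, loss_func=F.mse_loss):
    new_pred = pred if skip_frame == 0 else pred[:, :-skip_frame]
    new_target = target[:, skip_frame:]
    new_mask = mask[:, skip_frame:]
    data_shape, mask_shape = new_pred.shape, new_mask.shape
    loss = loss_func(new_pred, new_target, reduction='none')
    for _ in range(len(data_shape) - len(mask_shape)):
        new_mask = new_mask.unsqueeze(-1)
    return (loss * new_mask).sum() / new_mask.sum() / math.prod(data_shape[len(mask_shape):])

pred = torch.ones(1, 2, 3)          # [B=1, T=2, D=3]
target = torch.zeros(1, 2, 3)
mask = torch.tensor([[1.0, 0.0]])   # second timestep is padding
print(masked_loss(pred, target, mask).item())  # 1.0: MSE averaged over the valid frame only
```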

What coefficients do you use? By the way, regarding random shifting (#5), what padding value do you use?

class RandomShiftsAug(torch.nn.Module):
    def __init__(self, pad):
        super().__init__()
        self.pad = pad # what's the pad value here?

    def forward(self, x):
        x = x.float()
        b, t, c, h, w = x.size()
        assert h == w
        x = x.view(b*t, c, h, w)  # reshape x to [B*T, C, H, W]
        padding = tuple([self.pad] * 4)
        ...
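
For reference, my guess is that the full forward follows the standard DrQ-v2 `RandomShiftsAug`, adapted to the 5-D [B, T, C, H, W] input by folding time into the batch dimension (so each frame gets an independent shift; whether GR-1 instead shares the shift across time is an assumption on my part):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RandomShiftsAug(nn.Module):
    """DrQ-v2-style random shift augmentation for [B, T, C, H, W] video batches."""
    def __init__(self, pad):
        super().__init__()
        self.pad = pad  # pixels of replicate-padding per side, i.e. the max shift

    def forward(self, x):
        x = x.float()
        b, t, c, h, w = x.size()
        assert h == w
        x = x.view(b * t, c, h, w)  # fold time into the batch dimension
        padding = tuple([self.pad] * 4)
        x = F.pad(x, padding, 'replicate')
        # Build an identity sampling grid over the padded image
        eps = 1.0 / (h + 2 * self.pad)
        arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad,
                                device=x.device, dtype=x.dtype)[:h]
        arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
        base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(b * t, 1, 1, 1)
        # One random integer shift per frame, mapped to normalized grid units
        shift = torch.randint(0, 2 * self.pad + 1, size=(b * t, 1, 1, 2),
                              device=x.device, dtype=x.dtype)
        shift *= 2.0 / (h + 2 * self.pad)
        grid = base_grid + shift
        out = F.grid_sample(x, grid, padding_mode='zeros', align_corners=False)
        return out.view(b, t, c, h, w)
```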
bdrhtw commented 3 months ago

Hi @StarCycle , the coefficients for the arm loss, gripper loss, and video loss are 1, 0.01, and 0.01, respectively. For the padding value, we use 10 for static RGBs and 4 for hand RGBs.

StarCycle commented 3 months ago

Hi @bdrhtw , many thanks for the answer!
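
For anyone reading later, here is my understanding of how those coefficients plug into my snippet above (applying the video weight to both rgb terms is my assumption; the loss values are dummies for illustration):

```python
# coefficients per the answer above: arm = 1, gripper = 0.01, video = 0.01
loss = {'rgb_static': 0.5, 'rgb_gripper': 0.4,
        'action_arm': 0.2, 'action_gripper': 0.3}  # dummy scalars
total_loss = (0.01 * loss['rgb_static'] + 0.01 * loss['rgb_gripper']
              + 1.0 * loss['action_arm'] + 0.01 * loss['action_gripper'])
```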