bytedance / GR-1

Code for "Unleashing Large-Scale Video Generative Pre-training for Visual Robot Manipulation"
Apache License 2.0
175 stars 8 forks

Question about random shift in training #5

Closed · StarCycle closed this issue 6 months ago

StarCycle commented 6 months ago

Hi @bdrhtw @hongtaowu67

In appendix A1 you claimed:

Random shift augmentation is applied to the images.

Did you use random crop with padding, or torchvision's RandomResizedCrop? Either way, if RandomCrop is applied independently to a sequence of images (for example, video frames across multiple time steps), each frame will be cropped at a different random location, not at the same location.

This makes video prediction more difficult: the network has to predict every pixel of the next frame, but it does not know which area will be resized/shifted.
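
To illustrate what I mean, here is a small hypothetical sketch (not your code): applying torchvision's RandomCrop per frame samples a new location every time, whereas sampling the crop parameters once with RandomCrop.get_params and reusing them keeps all frames aligned.

import torch
from torchvision import transforms
from torchvision.transforms import functional as TF

frames = torch.rand(8, 3, 224, 224)  # a hypothetical clip of 8 RGB frames

# Frame-level cropping: each call samples a new location, so frames end up misaligned.
per_frame_crop = transforms.RandomCrop(200)
misaligned = torch.stack([per_frame_crop(f) for f in frames])

# Sequence-level cropping: sample the crop parameters once and reuse them for every frame.
i, j, h, w = transforms.RandomCrop.get_params(frames[0], output_size=(200, 200))
aligned = torch.stack([TF.crop(f, i, j, h, w) for f in frames])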

bdrhtw commented 6 months ago

Hi @StarCycle, we did not use torchvision; we followed HULC and used its random shift augmentation. Please refer to this code for more details.
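
Roughly, the idea (as in DrQ-style augmentation) is to pad the image with replicated borders and then resample it on a randomly shifted grid. A minimal frame-level sketch, not the exact HULC code:

import torch
import torch.nn.functional as F

def random_shift(x, pad=10):
    # x: [B, C, H, W] images; each image gets its own random integer shift.
    x = x.float()
    b, c, h, w = x.size()
    assert h == w
    x = F.pad(x, (pad,) * 4, "replicate")
    eps = 1.0 / (h + 2 * pad)
    # sampling grid covering the top-left h x w window of the padded image
    arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * pad,
                            device=x.device, dtype=x.dtype)[:h]
    arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
    base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
    base_grid = base_grid.unsqueeze(0).repeat(b, 1, 1, 1)
    # random integer shift in [0, 2*pad] pixels, converted to normalized grid units
    shift = torch.randint(0, 2 * pad + 1, size=(b, 1, 1, 2),
                          device=x.device, dtype=x.dtype)
    shift *= 2.0 / (h + 2 * pad)
    return F.grid_sample(x, base_grid + shift, padding_mode="zeros", align_corners=False)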

StarCycle commented 6 months ago

Hi @bdrhtw,

Thank you for the response! I tried the HULC shift function: it shifts the images and then interpolates them. Suppose there is an image sequence with 2 images and pad=50. Before shifting, they look like:

[image: before shifting, step 0]

[image: before shifting, step 1]

After shifting, they look like:

[image: after shifting to the left, step 0]

[image: after shifting to the right, step 1]

Notice that the first image is shifted to the left while the next one is shifted to the right. When the policy predicts the next frame, it does not know the shift direction of that frame, so video prediction is still made harder.
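
For reference, the behavior is easy to reproduce with a frame-level shift like the sketch above (random_shift is my hypothetical stand-in for the HULC function):

import torch

frame = torch.rand(3, 200, 200)
clip = torch.stack([frame, frame])      # two identical "frames", shape [T=2, C, H, W]
shifted = random_shift(clip, pad=50)    # frame-level: each frame draws its own shift
print((shifted[0] - shifted[1]).abs().max())  # almost surely > 0, i.e. the frames diverge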

Did you modify HULC's shift augmentation code or just ignore this issue by selecting a small padding value, e.g., 10?

Best, StarCycle

StarCycle commented 6 months ago

Hi @bdrhtw,

I modified the function to accept an image tensor with shape [B, T, C, H, W]. Now it should perform the shifting correctly, applying the same shift to every frame in a sequence.

import torch
import torch.nn.functional as F

class RandomShiftsAug(torch.nn.Module):
    def __init__(self, pad):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        x = x.float()
        b, t, c, h, w = x.size()
        assert h == w
        x = x.view(b*t, c, h, w)  # reshape x to [B*T, C, H, W]
        padding = tuple([self.pad] * 4)
        x = F.pad(x, padding, "replicate")
        h_pad, w_pad = h + 2*self.pad, w + 2*self.pad  # calculate the height and width after padding
        eps = 1.0 / (h_pad)
        arange = torch.linspace(-1.0 + eps, 1.0 - eps, h_pad, device=x.device, dtype=x.dtype)[:h]
        arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
        base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(b*t, 1, 1, 1)

        shift = torch.randint(0, 2 * self.pad + 1, size=(b, 1, 1, 1, 2), device=x.device, dtype=x.dtype)
        shift = shift.repeat(1, t, 1, 1, 1)  # repeat the shift for each image in the sequence
        shift = shift.view(b*t, 1, 1, 2)  # reshape shift to match the size of base_grid
        shift *= 2.0 / (h_pad)

        grid = base_grid + shift
        output = F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
        output = output.view(b, t, c, h, w)  # reshape output back to [B, T, C, H, W]
        return output
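
A quick sanity check for the sequence-level behaviour (assuming the [B, T, C, H, W] layout above):

aug = RandomShiftsAug(pad=10)
x = torch.rand(4, 1, 3, 200, 200).repeat(1, 8, 1, 1, 1)  # [B=4, T=8, C, H, W], identical frames within each sequence
y = aug(x)
# Every frame in a sequence received the same shift, so identical inputs stay identical.
print(torch.allclose(y[:, 0], y[:, 1]))  # True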


Please let me know if my understanding is wrong!

bdrhtw commented 6 months ago

If random shift augmentation is applied at the frame level, it results in different shifts across frames, and this would definitely make video prediction harder. If it is applied at the sequence level, all frames in the sequence are shifted identically. We tested both frame-level and sequence-level random shift augmentation and found that sequence-level augmentation performs better, though the advantage is not large.
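
To make the difference concrete, the only place the two variants differ is where the random shift is sampled. A hypothetical helper (sample_shifts is just an illustrative name, not code from the repo) for a [B, T, C, H, W] batch:

import torch

def sample_shifts(b, t, pad, per_frame, device="cpu"):
    if per_frame:
        # frame-level: b*t independent shifts, one per frame
        shift = torch.randint(0, 2 * pad + 1, size=(b * t, 1, 1, 2), device=device).float()
    else:
        # sequence-level: b shifts, each repeated across the t frames of its sequence
        shift = torch.randint(0, 2 * pad + 1, size=(b, 1, 1, 1, 2), device=device).float()
        shift = shift.repeat(1, t, 1, 1, 1).view(b * t, 1, 1, 2)
    return shift  # [B*T, 1, 1, 2], ready to be added to the base sampling grid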

StarCycle commented 6 months ago

Thank you for the great response!