Hi @StarCycle, we did not use torchvision; we followed HULC and used random shift augmentation. Please refer to this code for more details.
Hi @bdrhtw,
Thank you for the response! I tried the HULC shift function. It shifts images and then interpolates them. Suppose there is an image sequence with 2 images and pad=50. Before shifting, they look like:
image before shifting, step 0
image before shifting, step 1
After shifting, they look like:
image after shifting to the left, step 0
image after shifting to the right, step 1
Notice that the first image is shifted to the left and the second image is shifted to the right. When the policy predicts the next frame, it DOESN'T KNOW the shift direction of the next frame. This still makes video prediction difficult.
Did you modify HULC's shift augmentation code or just ignore this issue by selecting a small padding value, e.g., 10?
Best, StarCycle
Hi @bdrhtw,
I modified the function to accept an image tensor of shape [B, T, C, H, W]. Now it should do the shifting correctly.
```python
import torch
import torch.nn.functional as F


class RandomShiftsAug(torch.nn.Module):
    def __init__(self, pad):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        x = x.float()
        b, t, c, h, w = x.size()
        assert h == w
        x = x.view(b * t, c, h, w)  # reshape x to [B*T, C, H, W]
        padding = tuple([self.pad] * 4)
        x = F.pad(x, padding, "replicate")
        h_pad, w_pad = h + 2 * self.pad, w + 2 * self.pad  # height and width after padding
        eps = 1.0 / h_pad
        arange = torch.linspace(-1.0 + eps, 1.0 - eps, h_pad, device=x.device, dtype=x.dtype)[:h]
        arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
        base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(b * t, 1, 1, 1)
        # sample ONE shift per sequence, not one per frame
        shift = torch.randint(0, 2 * self.pad + 1, size=(b, 1, 1, 1, 2), device=x.device, dtype=x.dtype)
        shift = shift.repeat(1, t, 1, 1, 1)  # repeat the shift for each image in the sequence
        shift = shift.view(b * t, 1, 1, 2)  # reshape shift to match the size of base_grid
        shift *= 2.0 / h_pad
        grid = base_grid + shift
        output = F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
        output = output.view(b, t, c, h, w)  # reshape output back to [B, T, C, H, W]
        return output
```
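For example (the batch size, sequence length, image size, and pad value below are just for illustration), all frames within a sequence now receive the same shift, while different sequences receive different shifts:

```python
aug = RandomShiftsAug(pad=10)
x = torch.rand(2, 5, 3, 128, 128)  # [B, T, C, H, W]
out = aug(x)
print(out.shape)  # torch.Size([2, 5, 3, 128, 128])
```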
Please let me know if my understanding is wrong!
If random shift augmentation is applied at the frame level, it results in different shifts across frames, which would indeed make video prediction harder. If it is applied at the sequence level, all frames in a sequence are shifted identically. We tested both frame-level and sequence-level random shift augmentation and found that sequence-level augmentation performs better, though the advantage is not large.
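Concretely, the only difference between the two variants is how the shift is sampled; a minimal sketch reusing the variable names from the code above (not the exact code in the repo):

```python
# frame-level: every frame draws its own random shift
shift = torch.randint(0, 2 * pad + 1, size=(b * t, 1, 1, 2), device=x.device, dtype=x.dtype)

# sequence-level: one shift per sequence, repeated across all T frames
shift = torch.randint(0, 2 * pad + 1, size=(b, 1, 1, 1, 2), device=x.device, dtype=x.dtype)
shift = shift.repeat(1, t, 1, 1, 1).view(b * t, 1, 1, 2)
```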
Thank you for the great response!
Hi @bdrhtw @hongtaowu67
In appendix A1 you claimed:
Did you use random crop with padding or torchvision's RandomResizedCrop? If you apply RandomCrop to a sequence of images (for example, video frames across multiple time steps), each frame is cropped at a different random location rather than at the same location.
This makes video prediction more difficult: the network needs to predict every pixel of the new frame, but it does not know which area will be resized/shifted.
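For what it's worth, here is a small illustrative snippet (the image size, crop size, and padding values are made up) showing the difference between applying torchvision's RandomCrop per frame versus sampling one crop location and reusing it for the whole sequence:

```python
import torch
from torchvision import transforms
from torchvision.transforms import functional as TF

frames = torch.rand(5, 3, 200, 200)  # a toy sequence: [T, C, H, W]

# Per-frame: each call to RandomCrop draws a NEW crop location,
# so frames of the same sequence are cropped at different places.
crop = transforms.RandomCrop(184, padding=10, padding_mode="edge")
per_frame = torch.stack([crop(f) for f in frames])

# Sequence-level: draw the crop location once and reuse it for every frame.
padded = TF.pad(frames, [10, 10, 10, 10], padding_mode="edge")
i, j, h, w = transforms.RandomCrop.get_params(padded[0], output_size=(184, 184))
per_sequence = torch.stack([TF.crop(f, i, j, h, w) for f in padded])
```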