vt-vl-lab / FGVC

[ECCV 2020] Flow-edge Guided Video Completion

grid_sample() error caused by shape mismatch between feature maps and grid object #23

Open eduardathome opened 3 years ago

eduardathome commented 3 years ago

Hi, I ran into this error. I tried to work it out myself and found the cause, but no real solution yet. If there is interest, I'm willing to look for a reliable fix and share it. Below is a report detailing what I found:

Error and reproducibility:

Error text:

  File "e:/Work/FGVC/tool/video_completion.py", line 120, in calculate_flow
    _, flow = model(image1, image2, iters=20, test_mode=True)
  File "C:\Users\eduard\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "e:\Work\FGVC\RAFT\raft.py", line 127, in forward
    corr = corr_fn(coords1) # index correlation volume
  File "e:\Work\FGVC\RAFT\corr.py", line 51, in __call__
    corr = bilinear_sampler(corr, coords_lvl)
  File "e:\Work\FGVC\RAFT\utils\utils.py", line 66, in bilinear_sampler
    img = F.grid_sample(img, grid, align_corners=True)
  File "C:\Users\eduard\anaconda3\envs\torch\lib\site-packages\torch\nn\functional.py", line 3390, in grid_sample
    return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners)
RuntimeError: grid_sampler(): expected grid and input to have same batch size, but got input with sizes [2850, 1, 38, 75] and grid with sizes [2775, 9, 9, 2]

I ran into the error while running code cloned from https://github.com/vt-vl-lab/FGVC (version 1.0) with the following command:

    video_completion.py --mode object_removal --path E:/Work/video_segmentation/images_/ --path_mask E:/Work/video_segmentation/masks_/ --outroot ../result

The input was a set of 10 image/mask pairs with shape [3, 300, 600].

Error origin:

The function grid_sampler(...) is called by bilinear_sampler(...) with the arguments corr and coords_lvl, which have the shapes [2850, 1, 38, 75] and [2775, 9, 9, 2]:

    corr = bilinear_sampler(corr, coords_lvl)

grid_sampler(...) fails because their first (batch) dimensions differ: 2850 vs. 2775.
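To make the failing precondition concrete, here is a minimal pure-Python imitation of the batch-size check that the error message describes. The function name check_grid_sample_batch is hypothetical; it only mirrors the wording of the RuntimeError and is not PyTorch's implementation.

```python
def check_grid_sample_batch(input_shape, grid_shape):
    """Hypothetical stand-in for grid_sampler()'s batch-size precondition.

    input_shape is (N, C, H_in, W_in); grid_shape is (N, H_out, W_out, 2).
    Only the first (batch) dimensions have to agree.
    """
    if input_shape[0] != grid_shape[0]:
        raise RuntimeError(
            "expected grid and input to have same batch size, but got "
            f"input with sizes {list(input_shape)} and grid with sizes {list(grid_shape)}"
        )


# Shapes taken from the traceback in this report
corr_shape = (2850, 1, 38, 75)
coords_lvl_shape = (2775, 9, 9, 2)
try:
    check_grid_sample_batch(corr_shape, coords_lvl_shape)
except RuntimeError as err:
    print(err)
```

The check only compares the first dimension, so the differing spatial sizes (38 × 75 for the input, 9 × 9 for the grid) are not a problem in themselves; the 2850 vs. 2775 batch sizes are.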

In the next three chapters I follow each object's trail to find out why their shapes differ.


I. coords_lvl : CorrBlock.__call__(self, coords) -> centroid_lvl

In raft.py, initialize_flow() computes the grid size as (N, H//8, W//8) = (1, 37, 75) from an image with (N, H, W) = (1, 300, 600), because floor division drops the fraction: H/8 = 37.5 but H//8 = 37.

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape

        coords0 = coords_grid(N, H//8, W//8).to(img.device)
        coords1 = coords_grid(N, H//8, W//8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1
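For the 300 × 600 input, the integer division in initialize_flow() can be checked directly. This sketch only reproduces the size arithmetic; the helper name flow_grid_size is hypothetical.

```python
def flow_grid_size(height, width, stride=8):
    """Grid size as initialize_flow() computes it: floor division by the stride."""
    return height // stride, width // stride


h, w = 300, 600
print(h / 8, w / 8)          # 37.5 75.0 (exact quotients)
print(flow_grid_size(h, w))  # (37, 75): what coords_grid() receives
```

Because 300 is not divisible by 8, the fractional half-row is silently dropped, giving 37 × 75 = 2775 grid points.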

This propagates to corr_fn() -> CorrBlock.__call__(), which receives coordinates of shape torch.Size([1, 2, 37, 75]):

        # raft.py, RAFT.forward()
        coords0, coords1 = self.initialize_flow(image1)
        #...
        corr = corr_fn(coords1) # index correlation volume

        # corr.py, CorrBlock.__init__(): builds the correlation pyramid
        corr = CorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

In CorrBlock.__call__(...), these coordinates are reshaped into centroid_lvl with shape [2775, 1, 1, 2] (2775 = 37 × 75). Adding the 9 × 9 window of sampling offsets turns it into coords_lvl with shape [2775, 9, 9, 2], which is the grid passed to bilinear_sampler().
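The jump from [2775, 1, 1, 2] to [2775, 9, 9, 2] is ordinary broadcasting: in CorrBlock.__call__(), centroid_lvl is added to an offset grid delta_lvl of shape [1, 2r+1, 2r+1, 2], which is [1, 9, 9, 2] for the default radius r = 4. The pure-Python broadcast_shape helper below is a hypothetical, simplified illustration of the NumPy/PyTorch broadcasting rule for shapes of equal rank, not library code.

```python
def broadcast_shape(a, b):
    """Broadcast two shapes of equal rank (simplified NumPy/PyTorch rule)."""
    if len(a) != len(b):
        raise ValueError("this sketch only handles shapes of equal rank")
    out = []
    for x, y in zip(a, b):
        # Sizes must match, or one of them must be 1 (which is stretched).
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)
    return tuple(out)


radius = 4                 # CorrBlock default
h1, w1 = 37, 75            # grid size from initialize_flow()
centroid_lvl = (h1 * w1, 1, 1, 2)
delta_lvl = (1, 2 * radius + 1, 2 * radius + 1, 2)
print(broadcast_shape(centroid_lvl, delta_lvl))  # (2775, 9, 9, 2)
```

The first dimension of coords_lvl is therefore fixed entirely by the grid size computed in initialize_flow().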

II. corr <- fmaps:

corr ultimately takes its shape from the feature maps (fmaps), as detailed in chapter III.

The fmaps are generated (in this case) by a BasicEncoder(nn.Module). Its forward(self, x) method is:

    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x

It passes the image through several layers and returns a feature map whose final shape is exactly torch.Size([256, 38, 75]) per image. This shape propagates as described in chapter III.
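The height of 38 comes from the encoder's strided convolutions, which round odd sizes up rather than down. The sketch below applies the standard convolution output-size formula floor((n + 2p - k) / s) + 1 to three stride-2 stages. The layer parameters are an assumption based on RAFT's BasicEncoder (a 7 × 7 stride-2 stem with padding 3, then two residual stages whose first 3 × 3 convolution has stride 2 and padding 1); check them against the model definition in your copy.

```python
def conv_out(n, kernel, stride, padding):
    """Output length of a convolution along one axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


# Assumed stride-2 stages of the feature encoder: (kernel, stride, padding)
STRIDED_STAGES = [(7, 2, 3), (3, 2, 1), (3, 2, 1)]


def encoder_out(n):
    """Spatial size after all assumed stride-2 stages."""
    for kernel, stride, padding in STRIDED_STAGES:
        n = conv_out(n, kernel, stride, padding)
    return n


print(encoder_out(300), encoder_out(600))  # 38 75
```

With these parameters each stage computes ceil(n / 2), so the heights go 300 → 150 → 75 → 38, i.e. ceil(H / 8), while initialize_flow() uses floor(H / 8) = 37.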

III. fmap1, fmap2 -> CorrBlock.__init__(self, fmap1, fmap2, num_levels=4, radius=4) -> corr.shape

In raft.py, self.fnet([image1, image2]) returns fmap1 and fmap2, each with shape torch.Size([1, 256, 38, 75]):

    with autocast(enabled=self.args.mixed_precision):
        fmap1, fmap2 = self.fnet([image1, image2])

In CorrBlock.__init__(), these produce the correlation volume corr, reshaped to torch.Size([2850, 1, 38, 75]), where 2850 = 38 × 75.

It is then appended to self.corr_pyramid and finally passed, in __call__(), to bilinear_sampler() as the sampled input, still with shape torch.Size([2850, 1, 38, 75]), as the error message shows.
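Following the pyramid construction in CorrBlock.__init__() quoted in chapter I, every level keeps the same first dimension h1 × w1; only the last two dimensions shrink under F.avg_pool2d(corr, 2, stride=2), which floors odd sizes. A shape-only sketch (the helper name corr_pyramid_shapes is hypothetical; batch size 1 and a single correlation channel are assumed, as in this report):

```python
def corr_pyramid_shapes(h, w, num_levels=4):
    """Shapes of the correlation pyramid for an h x w feature map (batch 1)."""
    first_dim = h * w  # one correlation map per fmap1 pixel
    shapes = [(first_dim, 1, h, w)]
    for _ in range(num_levels - 1):
        h, w = h // 2, w // 2  # 2x2 average pooling, stride 2, no padding
        shapes.append((first_dim, 1, h, w))
    return shapes


for shape in corr_pyramid_shapes(38, 75):
    print(shape)
```

Since the sampling grid has first dimension 37 × 75 = 2775 while every pyramid level has 2850, the batch-size check fails already at level 0.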

IV. Possible solutions:

To make the shapes match, either the small CNN must be modified, or the way the grid shape is defined in initialize_flow() must change, from:

    coords0 = coords_grid(N, H//8, W//8).to(img.device)

to:

    coords0 = coords_grid(N, int(np.round(H/8)), int(np.round(W/8))).to(img.device)

However, I suspect this change should be made at other points in the implementation as well.
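Rounding fixes this particular 300 × 600 case (37.5 rounds to 38), but it does not match the encoder in general. Assuming the encoder output is ceil(H / 8), as three stride-2 stages with "same"-style padding produce, the comparison below shows where rounding falls short. Python's round(), like np.round, rounds exact halves to the nearest even integer, so 36.5 becomes 36. The helper fmap_size is hypothetical; -(-n // 8) is integer ceiling division.

```python
def fmap_size(n, stride=8):
    """Assumed encoder output length: ceil(n / stride), via integer arithmetic."""
    return -(-n // stride)


print("H  floor  round  ceil(assumed encoder)")
for h in (300, 292, 297, 304):
    print(h, h // 8, round(h / 8), fmap_size(h))
```

Only H = 300 and H = 304 agree with the assumed encoder size; for H = 292 and H = 297 the rounded grid is still one row short, so ceiling division (or reading the size from the feature map) is the safer choice.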

V. Observation

Dividing by 8 to predict the output shape of the convolutions is error-prone; the grid should instead match the encoder's actual output shape. If the CNN architecture is ever modified, the hard-coded division will also cause shape-mismatch errors.
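One way to follow this observation is to size the grid from the feature map itself, so that any change to the encoder is picked up automatically. The pure-Python helper below is hypothetical: it mimics the [N, 2, h, w] layout of RAFT's coords_grid() (x coordinates in channel 0, y in channel 1) without depending on PyTorch. In raft.py the corresponding change would be to pass fmap1.shape[2] and fmap1.shape[3] instead of H//8 and W//8, which assumes the grid is built after fnet() has run.

```python
def coords_grid_from_fmap(fmap_shape):
    """Hypothetical helper: build an [N, 2, h, w] coordinate grid whose
    spatial size is read from the feature map shape (N, C, h, w)."""
    batch, _, h, w = fmap_shape
    xs = [[float(x) for x in range(w)] for _ in range(h)]  # channel 0: x
    ys = [[float(y)] * w for y in range(h)]                # channel 1: y
    # The same rows are shared across batch entries; fine for a read-only sketch.
    return [[xs, ys] for _ in range(batch)]


grid = coords_grid_from_fmap((1, 256, 38, 75))
print(len(grid), len(grid[0]), len(grid[0][0]), len(grid[0][0][0]))  # 1 2 38 75
```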

gaochen315 commented 3 years ago

Hi @Edward334, thanks for this detailed report! Yes, the image inpainting network requires the height and width to be divisible by 8.

One trivial solution is to resize the image first, inpaint the missing region, and resize the image back. This is not ideal but can bypass the error. I'll try to find a better inpainting method.
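The resize workaround suggested above can be sketched as follows. The target size and interpolation are not specified in the thread; this sketch assumes each side is rounded up to the next multiple of 8 and uses nearest-neighbour sampling on a plain 2D list for brevity. The helper names are hypothetical, and a real pipeline would use its image library's resize function instead (typically bilinear for frames and nearest-neighbour for binary masks).

```python
def round_up(n, multiple=8):
    """Smallest multiple of `multiple` that is >= n."""
    return -(-n // multiple) * multiple


def resize_nearest(img, new_h, new_w):
    """Nearest-neighbour resize of a 2D list (rows of pixel values)."""
    h, w = len(img), len(img[0])
    return [[img[i * h // new_h][j * w // new_w] for j in range(new_w)]
            for i in range(new_h)]


frame = [[(r + c) % 256 for c in range(600)] for r in range(300)]  # dummy 300 x 600 frame
h, w = len(frame), len(frame[0])
work = resize_nearest(frame, round_up(h), round_up(w))  # 304 x 600, divisible by 8
# ... run flow estimation and inpainting on `work` here ...
restored = resize_nearest(work, h, w)                   # back to 300 x 600
print(len(work), len(work[0]), len(restored), len(restored[0]))  # 304 600 300 600
```

The round trip restores the original size but not necessarily the exact pixel values, which is one reason this workaround is "not ideal".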

CreativeSelf0 commented 2 years ago

Is this fixed?