NVIDIA / flownet2-pytorch

Pytorch implementation of FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks

Why are test images forced to be cropped? #60

Open TomHeaven opened 6 years ago

TomHeaven commented 6 years ago

I find that my test images are forced to be cropped. There is a cropping procedure in datasets.py to ensure the size of input images is a multiple of 64.

My concerns are:

Is the cropping procedure necessary? Is there a workaround to run FlowNet2 on images of arbitrary size, other than resizing the test images? There seems to be no such limitation in the original repo https://github.com/lmb-freiburg/flownet2.

HuangJunJie2017 commented 6 years ago

@TomHeaven

This condition is there to make sure that the sizes of the torch.cat operands are all the same: the size of the feature map obtained by 'torch.upsample' after 'torch.conv' will not match the corresponding skip connection if the input size is not a multiple of 2^6 = 64, because every submodule of FlowNet2 has 6 'torch.conv' layers with stride 2.
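
For intuition, here is a minimal sketch (generic stride-2 conv/deconv layers with the same size behavior, not the repo's actual modules) of how a non-multiple input size breaks a skip connection:

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1)             # halves H and W
    deconv = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2, padding=1)  # doubles H and W

    x = torch.randn(1, 1, 270, 270)   # 270 is not a multiple of 64
    c1 = conv(x)                      # (1, 1, 135, 135)
    c2 = conv(c1)                     # (1, 1, 68, 68): the odd 135 rounds up to 68
    u1 = deconv(c2)                   # (1, 1, 136, 136): no longer matches c1
    torch.cat((c1, u1), 1)            # RuntimeError: sizes must match except in dimension 1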

If you want to use arbitrary image sizes, you can add a padding layer to fix up the sizes. For example:

Define a class for the padding operation:

    import torch.nn as nn

    class padding(nn.Module):
        def __init__(self):
            super(padding, self).__init__()
            # a negative pad value removes one column / one row
            self.wpad = nn.ReplicationPad2d((0, -1, 0, 0))
            self.hpad = nn.ReplicationPad2d((0, 0, 0, -1))

        def forward(self, input, targetsize):
            if input.size()[2] != targetsize[2]:
                input = self.hpad(input)
            if input.size()[3] != targetsize[3]:
                input = self.wpad(input)
            return input

Define an instance of this class in the __init__ of the FlowNet modules:

self.pad = padding()

Pad before each torch.cat:

    flow6 = self.predict_flow6(out_conv6)
    flow6_up = self.upsampled_flow6_to_5(flow6)
    out_deconv5 = self.deconv5(out_conv6)
    # pad (crop) both to out_conv5's size before concatenating
    flow6_up = self.pad(flow6_up, out_conv5.size())
    out_deconv5 = self.pad(out_deconv5, out_conv5.size())
    concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
    out_interconv5 = self.inter_conv5(concat5)
    flow5 = self.predict_flow5(out_interconv5)
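
(As far as I can tell, ReplicationPad2d with a negative pad value actually removes one column/row, so this "padding" module really crops the upsampled tensors down by one to match out_conv5. One is always enough: a stride-2 conv maps an odd size n to (n+1)/2, so the deconv comes back at most one pixel too large.)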

tlatlbtle commented 6 years ago

mark, good issue.

ghost commented 6 years ago

Will the padding operation lead to the accuracy dropping?

riccardosamperna commented 6 years ago

I am also interested in this issue.

@HuangJunJie2017 do you think the accuracy of the optical flow is going to change?

Do I have to re-train the model in order to work with padded images?

HuangJunJie2017 commented 6 years ago

@riccardosamperna To the best of my knowledge, the accuracy will not change and the model does not need to be re-trained: when I visualize the result, it doesn't change after adding the padding procedure. Optical flow estimation is based on convolution, which is local, so padding only impacts the edge.

riccardosamperna commented 6 years ago

@HuangJunJie2017

Following your advice I made it work. Thanks a lot.

TomHeaven commented 6 years ago

@HuangJunJie2017 @riccardosamperna Thanks a lot. Maybe we can make a pull request that handles images of arbitrary size and then close the issue.

riccardosamperna commented 6 years ago

Good idea @TomHeaven, when I have some time I can try to make a pull request with what I have so far.

EnQing626 commented 6 years ago

Thanks @HuangJunJie2017, it works for me after adding the padding to the model!

huangbiubiu commented 5 years ago

Padding is needed at every torch.cat operation (concat4, concat3, and concat2 as well, not only concat5), is that correct?

EnQing626 commented 5 years ago

@huangbiubiu correct.
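
For example, following the template above, the concat4 stage would look roughly like this (a sketch using the FlowNetSD-style layer names from the earlier comment; check them against the module you are editing):

    flow5_up = self.upsampled_flow5_to_4(flow5)
    out_deconv4 = self.deconv4(concat5)
    # pad (crop) both to out_conv4's size before concatenating
    flow5_up = self.pad(flow5_up, out_conv4.size())
    out_deconv4 = self.pad(out_deconv4, out_conv4.size())
    concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
    out_interconv4 = self.inter_conv4(concat4)
    flow4 = self.predict_flow4(out_interconv4)

And likewise for concat3 and concat2.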

aaab8b commented 5 years ago

> @riccardosamperna To the best of my knowledge, the accuracy will not change and the model does not need to be re-trained: when I visualize the result, it doesn't change after adding the padding procedure. Optical flow estimation is based on convolution, which is local, so padding only impacts the edge.

Hi Junjie, I am wondering: have you ever tried padding zeros onto the conv feature maps instead of cropping the deconv and flow-up feature maps? What is the result?

HuangJunJie2017 commented 5 years ago

@aaab8b Not yet; you can try it and compare the results. That way you need to deal with a more complex problem, namely matching the sizes of the inputs of the concat operation, and as a result you will get an optical flow field whose size is bigger than the input.

aaab8b commented 5 years ago

> @aaab8b Not yet; you can try it and compare the results. That way you need to deal with a more complex problem, namely matching the sizes of the inputs of the concat operation, and as a result you will get an optical flow field whose size is bigger than the input.

I actually changed your single-pair code like this to run a demo (using the pytorch-0.4 branch):

    import torch
    import numpy as np
    import argparse

    from networks.FlowNetS import FlowNetS  # the path depends on where you put this module
    from utils.frame_utils import read_gen  # the path depends on where you put this module

    def main():
        # obtain the necessary args to construct the flownet framework
        parser = argparse.ArgumentParser()
        parser.add_argument('--fp16', action='store_true',
                            help='Run model in pseudo-fp16 mode (fp16 storage fp32 math).')
        parser.add_argument("--rgb_max", type=float, default=255.)
        args = parser.parse_args()

        # initialize a net
        net = FlowNetS(args=args, input_channels=6, batchNorm=False).cuda()
        # print("net structure={}".format(net))

        # load the state_dict
        dict = torch.load("/data/bingzd/models/FlowNet2-S_checkpoint.pth.tar")
        net.load_state_dict(dict["state_dict"])

        # load the image pair; you can find this operation in dataset.py
        pim1 = read_gen("/data/Database_pub/000050.jpg")
        pim2 = read_gen("/data/Database_pub/000051.jpg")
        # images = [pim1, pim2]
        # images = np.array(images).transpose(3, 0, 1, 2)
        images = torch.cat((torch.from_numpy(pim1.astype(np.float32)).permute(2, 0, 1),
                            torch.from_numpy(pim2.astype(np.float32)).permute(2, 0, 1)), dim=0)
        net.eval()
        im = images.unsqueeze(0).cuda()

        # process the image pair to obtain the flow
        result = net(im)[0].squeeze()
        print("result shape={}".format(result.shape))
        data = result.data.cpu().numpy()

    if __name__ == '__main__':
        main()

The shape of the input is 1920x1080, but when I check the shape of the result flow it turns out to be 480x270. Do you have any idea why? And are the values of the resulting dense flow the delta-x and delta-y for the original coordinates, or should they be divided by 1080/270 = 4?

HuangJunJie2017 commented 5 years ago

@aaab8b The resolution of the initial result is 1/4 of the input; you can add an upsampling layer with a factor of 4 to obtain a result with the same resolution as the input.
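
E.g., something like this (a sketch; on the old pytorch-0.4 branch F.interpolate was called F.upsample, and note that when you enlarge a flow field spatially you usually also multiply the flow values by the same factor, since displacements are measured in pixels):

    import torch.nn.functional as F

    # result: (2, H/4, W/4) flow predicted by the network
    flow = F.interpolate(result.unsqueeze(0), scale_factor=4,
                         mode='bilinear', align_corners=False)
    flow = flow.squeeze(0) * 4.0      # (2, H, W), displacements in input pixels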

aaab8b commented 5 years ago

> @aaab8b The resolution of the initial result is 1/4 of the input; you can add an upsampling layer with a factor of 4 to obtain a result with the same resolution as the input.

thank you so much

huangbiubiu commented 5 years ago

@HuangJunJie2017 What if the input width (or height) is not divisible by 4 (e.g. 270 / 4 = 67.5)? There will be an error at diff_img0 = x[:, :3, :, :] - resampled_img1 in FlowNet2:

RuntimeError: The size of tensor a (270) must match the size of tensor b (272) at non-singleton dimension 2

Is there any workaround?

HuangJunJie2017 commented 5 years ago

@huangbiubiu Refer to my previous comment above, where the padding operation fixes this problem.

huangbiubiu commented 5 years ago

@HuangJunJie2017 Thanks for your great solution! However, I don't think adding pad before concat can fix my issue.

I think adding the pad can avoid the error at concat, but consider an input of size (3, 270, 270) in the FlowNet2 module. After flownetc:

https://github.com/NVIDIA/flownet2-pytorch/blob/252686cb83336fc969ae45d0769ec4a089745839/models.py#L118-L119

The shape of flownetc_flow2 turns out to be (1, 2, 68, 68) (roughly 1/4 of the input: 270 / 4 = 67.5, rounded up to 68). After upsampling, the shape of flownetc_flow is (1, 2, 272, 272).

When calculating diff:

https://github.com/NVIDIA/flownet2-pytorch/blob/252686cb83336fc969ae45d0769ec4a089745839/models.py#L121-L124

The shape of resampled_img1 is (1, 3, 272, 272), which is incompatible with x[:, :3, :, :], so diff_img0 = x[:, :3, :, :] - resampled_img1 raises an error. I did add padding in FlowNetC, but it looks like this does not fix the problem.

Is there any solution to that? Or did I miss something important?

HuangJunJie2017 commented 5 years ago

@huangbiubiu When you come across a problem like that, you just crop or pad the tensors until they match each other. Another way to solve this is to pad the input image so that its resolution is a multiple of 64, and crop the output to the desired size: input 270 >> padded input 320 >> output 320 >> cropped output 270.
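
The second approach can be wrapped around the whole network without touching the model code. A minimal sketch (run_with_padding is a hypothetical helper; it assumes the network returns flow at the full input resolution, as the complete FlowNet2 module does):

    import torch.nn.functional as F

    def run_with_padding(net, im, multiple=64):
        # im: (N, C, H, W); pad on the right/bottom up to the next multiple of 64
        h, w = im.size(2), im.size(3)
        pad_h = (multiple - h % multiple) % multiple
        pad_w = (multiple - w % multiple) % multiple
        im = F.pad(im, (0, pad_w, 0, pad_h), mode='replicate')
        flow = net(im)
        return flow[:, :, :h, :w]     # crop the flow back to the original size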

Lywzz commented 5 years ago

@HuangJunJie2017 I have defined a class for the padding operation, but I get an error: AttributeError: 'FlowNetC' object has no attribute 'inter_conv5'. Please have a look, thanks.

HuangJunJie2017 commented 5 years ago

@Lywzz My solution is a template specific to FlowNet2; you have to understand it and change parts of it before using it somewhere else (e.g. FlowNetC).

Lywzz commented 5 years ago

@HuangJunJie2017 I used the FlowNet2 model (640M), so I don't know what went wrong. Thanks for your reply.

HuangJunJie2017 commented 5 years ago

@Lywzz
Analyze the error: "AttributeError: 'FlowNetC' object has no attribute 'inter_conv5'" means you are using inter_conv5 in the FlowNetC model, but FlowNetC has no attribute 'inter_conv5'. 'inter_conv5' is an attribute of FlowNet2 and is not included in FlowNetC. You need to understand which parts are necessary for the padding operation and not just copy all of the template code; the lines 'out_interconv5 = self.inter_conv5(concat5)' and 'flow5 = self.predict_flow5(out_interconv5)' in the template are unnecessary for FlowNetC.
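
For reference, the corresponding FlowNetC block has no inter_conv, so the padded version would look roughly like this (a sketch; double-check the layer names against FlowNetC.py):

    flow6 = self.predict_flow6(out_conv6)
    flow6_up = self.upsampled_flow6_to_5(flow6)
    out_deconv5 = self.deconv5(out_conv6)
    # pad (crop) to out_conv5's size before concatenating
    flow6_up = self.pad(flow6_up, out_conv5.size())
    out_deconv5 = self.pad(out_deconv5, out_conv5.size())
    concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
    flow5 = self.predict_flow5(concat5)   # FlowNetC predicts directly from the concat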

TyroneLi commented 4 years ago

> @HuangJunJie2017 Thanks for your great solution! However, I don't think adding pad before concat can fix my issue.
>
> I think adding the pad can avoid the error at concat, but consider an input of size (3, 270, 270) in the FlowNet2 module. After flownetc:
>
> https://github.com/NVIDIA/flownet2-pytorch/blob/252686cb83336fc969ae45d0769ec4a089745839/models.py#L118-L119
>
> The shape of flownetc_flow2 turns out to be (1, 2, 68, 68) (roughly 1/4 of the input: 270 / 4 = 67.5, rounded up to 68). After upsampling, the shape of flownetc_flow is (1, 2, 272, 272).
>
> When calculating diff:
>
> https://github.com/NVIDIA/flownet2-pytorch/blob/252686cb83336fc969ae45d0769ec4a089745839/models.py#L121-L124
>
> The shape of resampled_img1 is (1, 3, 272, 272), which is incompatible with x[:, :3, :, :], so diff_img0 = x[:, :3, :, :] - resampled_img1 raises an error. I did add padding in FlowNetC, but it looks like this does not fix the problem.
>
> Is there any solution to that? Or did I miss something important?

Hi, how did you fix this? I met the same error, and I think simply resizing the input or cropping the output will have a small impact on the optical flow result.

lihaolin1 commented 4 years ago

Can I just change the code in dataset.py from "//64) * 64" to "//32) * 32", just to let the code process 160x128 images? And will it influence the performance of the result?