ltkong218 / FastFlowNet

FastFlowNet: A Lightweight Network for Fast Optical Flow Estimation (ICRA 2021)
MIT License

Results are inconsistent with expectations #5

Closed: dongxuanlb closed this issue 3 years ago

dongxuanlb commented 3 years ago

Platform: Jetson Xavier, JetPack 4.4.1, CUDA 10.2, PyTorch 1.6

The correlation_package in this repo was written for PyTorch 0.4 and is not compatible with PyTorch 1.6, so I used the one from flownet2-pytorch (https://github.com/NVIDIA/flownet2-pytorch/tree/master/networks/correlation_package) instead.

After that, everything ran without errors.

The output is:

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:3384: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details. warnings.warn("Default grid_sample and affine_grid behavior has changed "

However, the result is not the same as the one shown in the repository.

[Image: flow output]

ltkong218 commented 3 years ago

You can try the PyTorch Correlation module, which supports newer versions of PyTorch such as 1.2 and 1.6, and refer to the issue "Installing Correlation package". Also, please pay attention to the implementation of the warping layer.
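The grid_sample warning in the output above is relevant to the warping layer: since PyTorch 1.3.0, `grid_sample` defaults to `align_corners=False`, and the warning says to pass `align_corners=True` to get the old behavior that code written for PyTorch 0.4.1 would have relied on. Below is a minimal sketch of a backward-warping layer that passes the flag explicitly. It assumes the flow is given in pixels with channels (u, v); the function name `warp` is hypothetical, and this is not necessarily how FastFlowNet's own warping layer is written, so compare it against the repository code before relying on it.

```python
import torch
import torch.nn.functional as F


def warp(feat, flow):
    """Hypothetical backward warp: feat (B, C, H, W), flow (B, 2, H, W) in pixels.

    Output pixel (row i, col j) samples feat at (i + v, j + u); samples that fall
    outside the image are filled with zeros.
    """
    b, _, h, w = feat.shape
    # Pixel coordinate grids broadcast to (B, H, W).
    cols = torch.arange(w, dtype=feat.dtype, device=feat.device).view(1, 1, w).expand(b, h, w)
    rows = torch.arange(h, dtype=feat.dtype, device=feat.device).view(1, h, 1).expand(b, h, w)
    px = cols + flow[:, 0]
    py = rows + flow[:, 1]
    # Normalize to [-1, 1] with the align_corners=True convention:
    # -1 maps to pixel 0 and +1 maps to pixel W-1 (or H-1).
    gx = 2.0 * px / max(w - 1, 1) - 1.0
    gy = 2.0 * py / max(h - 1, 1) - 1.0
    grid = torch.stack((gx, gy), dim=3)  # (B, H, W, 2); last dim is (x, y)
    # Pass align_corners explicitly so the result does not depend on the version default.
    return F.grid_sample(feat, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
```

With zero flow this returns the input unchanged up to floating-point error. Normalizing for `align_corners=True` but calling `grid_sample` with the default `False` shifts samples by up to half a pixel, which is one plausible source of results that differ from the reference outputs.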

dongxuanlb commented 3 years ago

Hi, I have installed the PyTorch Correlation module, but I noticed that its constructor arguments differ from those of the original Correlation module:

SpatialCorrelationSampler(1, 9, 1, 0, 1) <-> Correlation(pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1)
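Reading the two constructors side by side, the arguments appear to line up as follows (our interpretation of both modules, not a documented equivalence): a max_displacement of 4 with stride2=1 covers displacements -4..4 on each axis, so patch_size = 2 * 4 + 1 = 9; stride1 corresponds to stride; stride2 corresponds to dilation_patch; and with kernel_size=1 no padding is needed for SpatialCorrelationSampler to keep the input resolution. The hypothetical helper below encodes only the configuration quoted above; check the mapping against both implementations before using it elsewhere.

```python
def sampler_args_from_correlation(pad_size, kernel_size, max_displacement,
                                  stride1, stride2):
    """Hypothetical helper: translate FlowNet2-style Correlation arguments into
    keyword arguments for SpatialCorrelationSampler.

    Only kernel_size=1 with pad_size == max_displacement (output keeps the input
    resolution) is handled; other settings need a closer look.
    """
    if kernel_size != 1 or pad_size != max_displacement:
        raise NotImplementedError("only kernel_size=1 with pad_size == max_displacement")
    if max_displacement % stride2 != 0:
        raise ValueError("max_displacement must be a multiple of stride2")
    # Displacements -max_displacement..+max_displacement in steps of stride2.
    patch_size = 2 * (max_displacement // stride2) + 1
    return {
        "kernel_size": 1,
        "patch_size": patch_size,
        "stride": stride1,
        "padding": 0,
        "dilation_patch": stride2,
    }


args = sampler_args_from_correlation(pad_size=4, kernel_size=1, max_displacement=4,
                                     stride1=1, stride2=1)
print(args)                     # {'kernel_size': 1, 'patch_size': 9, 'stride': 1, 'padding': 0, 'dilation_patch': 1}
print(args["patch_size"] ** 2)  # 81 cost-volume channels
```

If we read the SpatialCorrelationSampler constructor correctly, its positional order is kernel_size, patch_size, stride, padding, dilation, dilation_patch, so the fifth argument in `SpatialCorrelationSampler(1, 9, 1, 0, 1)` is the kernel `dilation` rather than `dilation_patch`; the two calls agree here only because both values are 1.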

ltkong218 commented 3 years ago
import torch
from spatial_correlation_sampler import SpatialCorrelationSampler

input1 = torch.randn(2, 32, 48, 64).cuda()
input2 = torch.randn(2, 32, 48, 64).cuda()

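# define a correlation module
# positional args: kernel_size=1, patch_size=9 (displacements -4..4), stride=1, padding=0, dilation=1
correlation_sampler = SpatialCorrelationSampler(1, 9, 1, 0, 1)

# direct output has shape (B, 9, 9, H, W)
output = correlation_sampler(input1, input2)

# reshape output to be a 3D cost volume (81 displacement channels per sample)
# and normalize by the number of feature channels
b, c, h, w = input1.shape
output = output.view(b, -1, h, w) / c

print(output.shape)  # torch.Size([2, 81, 48, 64])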

I have checked that the Correlation module in my repository (PyTorch 0.4.1) and the SpatialCorrelationSampler module (PyTorch 1.2) produce the same results. Remember to reshape the direct output of SpatialCorrelationSampler and divide it by the number of channels. Also, check whether the input tensors are the same.
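To check the reshaped and normalized output independently of either CUDA extension, a slow pure-PyTorch reference cost volume can help. This is a sketch under our own assumptions: zero padding outside the image, channels ordered with vertical displacement major (k = (dy + 4) * 9 + (dx + 4) for a max displacement of 4), and normalization by the channel count as described above. The function name `naive_cost_volume` is hypothetical, and the channel ordering should be confirmed against the extension's output before drawing conclusions.

```python
import torch
import torch.nn.functional as F


def naive_cost_volume(f1, f2, max_disp=4):
    """Hypothetical reference correlation for kernel_size=1, stride=1.

    out[:, k, y, x] = <f1[:, :, y, x], f2[:, :, y + dy, x + dx]> / C, with
    k = (dy + max_disp) * (2 * max_disp + 1) + (dx + max_disp); f2 is zero
    outside the image.
    """
    _, c, h, w = f1.shape
    # Zero-pad f2 by max_disp on the left/right (W) and top/bottom (H).
    f2p = F.pad(f2, (max_disp, max_disp, max_disp, max_disp))
    out = []
    for dy in range(-max_disp, max_disp + 1):      # vertical displacement (outer)
        for dx in range(-max_disp, max_disp + 1):  # horizontal displacement (inner)
            shifted = f2p[:, :, max_disp + dy:max_disp + dy + h,
                          max_disp + dx:max_disp + dx + w]
            out.append((f1 * shifted).sum(dim=1) / c)
    return torch.stack(out, dim=1)  # (B, (2 * max_disp + 1) ** 2, H, W)
```

On a GPU machine, comparing `naive_cost_volume(input1, input2)` with the reshaped `output` from the snippet above via `torch.allclose(..., atol=1e-5)` is one way to test whether the two agree; a mismatch points to either the normalization or the channel-ordering assumption.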