ClementPinard / FlowNetPytorch

Pytorch implementation of FlowNet by Dosovitskiy et al.
MIT License

KITTI flow reading is not correct #23

Closed mathmanu closed 6 years ago

mathmanu commented 6 years ago

Thank you for this simple and readable code. I am also glad that this works with Python 3.5 and PyTorch 0.3. I look forward to you adding other networks such as FlowNetC and FlowNet 2.0. Other metrics, such as the percentage of outliers, would also be a great addition.

KITTI flow GT is sparse, but this is not taken into account when reading the flow or during training. I suggest the following changes.

In KITTI.py, the ground-truth flow reading is not correct. Based on the KITTI devkit's flow reading script and the readme.txt there, this is what I wrote:

    import cv2
    import numpy as np

    def load_flow_from_png(png_path):
        # read using cv2 and convert from BGR to RGB;
        # scipy cannot handle 16-bit images, hence cv2 is used.
        flo_img = cv2.imread(png_path, -1)
        flo_img = flo_img[:, :, ::-1].astype(float)

        # see the readme file in the KITTI devkit and its flow reader functions
        mask = np.minimum(flo_img[:, :, 2], 1)
        not_valid = (mask == 0)
        valid = (mask != 0)
        flo_img = flo_img[:, :, 0:2]
        flo_img = flo_img - 32768
        flo_img = flo_img / float(64.0)

        # value 0 is used to indicate invalid flow.
        # flow that is actually valid but zero is set to a very small value
        eps = 1e-10
        flo_img[np.abs(flo_img) < eps] = eps

        # invalid flow is indicated by 0
        flo_img[not_valid, :] = float(0.)
        return flo_img
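
For reference, the (value - 32768) / 64 step above just inverts how the devkit stores flow in the 16-bit PNG (flow * 64 + 2^15, per its readme). A tiny round-trip check with an illustrative value:

    import numpy as np

    u = -3.25                                    # an illustrative flow value in pixels
    stored = np.uint16(round(u * 64.0) + 2**15)  # how the devkit encodes it into the PNG
    decoded = (float(stored) - 32768) / 64.0
    print(decoded)                               # -3.25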

Apart from the above function, the sparse flag has to be passed into several functions. I added a flag called sparse_gt:

    if args.sparse_gt is None:
        args.sparse_gt = ('KITTI' in args.dataset)

and this flag is passed to all the relevant functions, such as multiscaleEPE, one_scale, realEPE, EPE, etc.

With these changes, I am getting more meaningful EPE values.
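
For completeness, a minimal sketch of how EPE() might skip invalid positions when the sparse flag is set, assuming the zero-means-invalid convention from the function above (signature and details are illustrative, not the repo's exact code):

    import torch

    def EPE(output, target, sparse=False, mean=True):
        # per-pixel endpoint error: L2 norm over the two flow channels
        epe_map = torch.norm(target - output, p=2, dim=1)
        if sparse:
            # with the zero-means-invalid convention, keep only valid positions
            valid = (target[:, 0] != 0) | (target[:, 1] != 0)
            epe_map = epe_map[valid]
        return epe_map.mean() if mean else epe_map.sum()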

Kindly fix this issue.

mathmanu commented 6 years ago

Sparse resampling using max pooling is not correct for sparse flow with negative values, so I introduced an adaptive_minmax_pool2d() that takes care of them. With this, I can see that negative values are also learned properly with the sparse KITTI GT.

    def adaptive_minmax_pool2d(input, size):
        maskp = (input >= 0).float()
        maskn = (input < 0).float()
        output = nn.functional.adaptive_max_pool2d(input * maskp, size) - \
                 nn.functional.adaptive_max_pool2d(-input * maskn, size)
        return output
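
A quick illustrative check of the intent, assuming the function above, on a sparse map whose only valid value is negative:

    import torch
    import torch.nn as nn

    x = torch.tensor([[[[0., -3.],
                        [0.,  0.]]]])  # one valid (negative) value, zeros mark invalid positions
    print(nn.functional.adaptive_max_pool2d(x, (1, 1)))  # 0.  -> plain max pooling keeps an invalid zero
    print(adaptive_minmax_pool2d(x, (1, 1)))             # -3. -> the valid negative value survives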

    def multiscaleEPE(network_output, target_flow, weights=None, sparse=False):
        def one_scale(output, target, sparse):
            b, c, h, w = output.size()

            if sparse:
                # target_scaled = nn.functional.adaptive_max_pool2d(target, (h, w))
                target_scaled = adaptive_minmax_pool2d(target, (h, w))
            else:
                target_scaled = nn.functional.adaptive_avg_pool2d(target, (h, w))
            return EPE(output, target_scaled, sparse=sparse, mean=True)

        if type(network_output) not in [tuple, list]:
            network_output = [network_output]
        if weights is None:
            weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # as in original article
        assert(len(weights) == len(network_output))

        loss = 0
        for output, weight in zip(network_output, weights):
            loss += weight * one_scale(output, target_flow, sparse)
        return loss

mathmanu commented 6 years ago

[updated]

I also made a small change due to a crash in the function EPE():

    valid = (mask == False)
    EPE_map = EPE_map[valid]

I have a question: in the function one_scale() under multiscaleEPE(), shouldn't mean be set to True? Maybe it was False because sparse GT was not handled correctly. Now that sparse GT is handled correctly with the above changes, we can go back to mean=True in this function:

    return EPE(output, target_scaled, sparse=sparse, mean=True)

ClementPinard commented 6 years ago

Hello, thanks a lot for your issue ! KITTI flow learning was not tested, it was only written as a mockup to be improved. So your help is much appreciated !

For mean EPE, it is actually meant to be that way, see #17: EPE can be averaged, but then you need to change the scale weights for the training to match the original implementation.

A patch for KITTI is coming. Or you might want to make a PR so you get credited for it, up to you !

Clément

mathmanu commented 6 years ago

Hi, I am glad that I could contribute. Mean EPE with appropriate scale weights seems more logical to me.

mathmanu commented 6 years ago

As I am behind a firewall, making PRs is more involved for me. It is easiest for me to contribute this way, through issues.

ClementPinard commented 6 years ago

After some calculations, the equivalent weights for training with mean EPE change from [0.005, 0.01, 0.02, 0.08, 0.32] to [716.8, 358.4, 179.2, 179.2, 179.2].

I am not particularly happy with those ugly, gigantic numbers, and I prefer to keep it the same as the original implementation :p I might add an addendum to the README to make it clear that the EPE is not the same when training and when validating (hence the real EPE metric).

mathmanu commented 6 years ago

Hi, I am not able to understand this change. I switched to using the mean loss for both train and val, and I am getting comparable losses for training and validation. I also had to make sure that the sparse flag is passed down correctly into all functions. In both training and validation, the loss is then calculated as the mean over only the points where the target is available (sparse), hence the values are comparable.

Am I missing something? Could you show some more details of your computation?

ClementPinard commented 6 years ago

The value to compare between training and validation is flow2_EPEs, which computes the mean EPE on full-resolution target flow maps. It is more meaningful than the training loss, which deals with several different scales, but not even the full-resolution one. When training you get both losses and flow2_EPEs. You must ensure that losses goes down, but only flow2_EPEs can give you a good hint of how well the network performs.

I have committed some changes based on your comments: invalid flow values are now set to NaN, because MaxPooling does not propagate NaN values. This lets us avoid designing a custom pooling layer. Hope it'll be good for you!

mathmanu commented 6 years ago

I have a couple of questions:

Question 1:

   # invalid flow is defined with one of flow coordinate to be NaN
    mask = (target_flow[:,0] != target_flow[:,0]) | (target_flow[:,1] != target_flow[:,1])

    EPE_map = EPE_map[~mask.detach()]

Mask indicates the positions of invalid flow. Then ~mask indicates positions of valid flow. Why are you detaching the valid positions? Shouldn't it be the other way?

[UPDATE] Thinking more about it, I think I got the answer. What is detached is the mask (and not EPE_map). Looks good, although I wonder whether there is a need to detach this mask, given that it is based on the target?
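
(Side note: the target_flow[:,0] != target_flow[:,0] test works because NaN is the only value that is not equal to itself, e.g.:)

    import torch

    t = torch.tensor([1.0, float('nan')])
    print(t != t)  # tensor([False, True]) : only the NaN position is flagged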

Question 2:

In the function one_scale(), there is the following line: return EPE(output, target_scaled, sparse, mean=False). Are you planning to set mean=True?

Question 3:

Could you please explain how you arrived at the values [716.8, 358.4, 179.2, 179.2, 179.2]? I was not able to understand this part. They seem too high, and would be equivalent to using a very high learning rate?

ClementPinard commented 6 years ago

Sorry, I have been a little too quick on patching, another one is arriving.

Detaching only means that you detach the variable from the gradient graph, i.e. you don't backpropagate the gradient through it. It makes sense here because the mask is not differentiable. I ended up taking the tensor out of the variable, because I kept having problems.

I am not planning on computing the training loss with mean EPE because, as question 3 shows, the equivalent coefficients are too high and don't seem very natural.

When computing a sum instead of a mean, you don't divide by the number of elements. That means that if you have H*W pixels, your loss (and gradient) will be H*W times higher. The sum actually still averages over the batch, in order to keep the same order of magnitude with differing batch sizes. So to pass from [0.005, 0.01, 0.02, 0.08, 0.32] to equivalent gradients, knowing that the input size during training is [320, 448], you multiply every coeff by (320*448)/(downscale_factor), which is a lot.
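
Concretely, a small sketch that reproduces the numbers above (the per-scale divisors 1, 4, 16, 64 and 256 are inferred from the resulting weights, not taken from the repo):

    h, w = 320, 448
    sum_weights = [0.005, 0.01, 0.02, 0.08, 0.32]
    downscale = [1, 4, 16, 64, 256]  # assumed per-scale element-count divisors
    mean_equivalent = [round(wt * (h * w) / d, 1) for wt, d in zip(sum_weights, downscale)]
    print(mean_equivalent)           # [716.8, 358.4, 179.2, 179.2, 179.2]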

mathmanu commented 6 years ago

Thanks for the explanation.

Should it be sum() or mean() for the loss function? I have been seeing mean() everywhere, including in SfMLearner-Pytorch :)

    def photometric_reconstruction_loss(...
        ...
        reconstruction_loss += diff.abs().mean()

The following thread explains what I was saying about the chance of exploding gradients due to high equivalent learning rate if you use sum: https://stackoverflow.com/questions/41954308/loss-function-works-with-reduce-mean-but-not-reduce-sum
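
A tiny sketch of that scale difference (sizes are arbitrary; the gradient from a sum is H*W times larger than from a mean):

    import torch

    pred = torch.zeros(1, 2, 4, 4, requires_grad=True)
    target = torch.ones(1, 2, 4, 4)

    torch.norm(target - pred, p=2, dim=1).mean().backward()
    grad_mean = pred.grad.abs().max().item()

    pred.grad = None
    torch.norm(target - pred, p=2, dim=1).sum().backward()
    grad_sum = pred.grad.abs().max().item()

    print(grad_sum / grad_mean)  # 16.0, i.e. H*W for this 4x4 example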

mathmanu commented 6 years ago

@ClementPinard Regarding the line: EPE_map = EPE_map[~mask.detach()]

~mask gave me an error in the past. I would have to use something like:

    mask = mask.detach()
    EPE_map = EPE_map[mask == False]

ClementPinard commented 6 years ago

Sorry for the lack of response, I am working on it. In the meantime, I advise you to stay on your own fix. It will be fixed tomorrow.

mathmanu commented 6 years ago

Another observation: the cropping co_transform was removed from the KITTI test split.

OLD:

    test_dataset = ListDataset(root, test_list, transform, target_transform,
                               flow_transforms.CenterCrop((320,1216)), loader=KITTI_loader)

NEW:

    test_dataset = ListDataset(root, test_list, transform, target_transform, loader=KITTI_loader)

This causes the following error (since the KITTI images are not all of the same size):

File "main.py", line 417, in main() File "main.py", line 209, in main EPE = validate(val_loader, model, epoch, output_writers) File "main.py", line 288, in validate for i, (input, target) in enumerate(val_loader): File "/user/a0393608/files/apps/anaconda2/envs/pytorch/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 210, in next return self._process_next_batch(batch) File "/user/a0393608/files/apps/anaconda2/envs/pytorch/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 230, in _process_next_batch raise batch.exc_type(batch.exc_msg) RuntimeError: Traceback (most recent call last): File "/user/a0393608/files/apps/anaconda2/envs/pytorch/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 42, in _worker_loop samples = collate_fn([dataset[i] for i in batch_indices]) File "/user/a0393608/files/apps/anaconda2/envs/pytorch/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 119, in default_collate return [default_collate(samples) for samples in transposed] File "/user/a0393608/files/apps/anaconda2/envs/pytorch/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 119, in return [default_collate(samples) for samples in transposed] File "/user/a0393608/files/apps/anaconda2/envs/pytorch/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 119, in default_collate return [default_collate(samples) for samples in transposed] File "/user/a0393608/files/apps/anaconda2/envs/pytorch/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 119, in return [default_collate(samples) for samples in transposed] File "/user/a0393608/files/apps/anaconda2/envs/pytorch/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 96, in default_collate return torch.stack(batch, 0, out=out) File "/user/a0393608/files/apps/anaconda2/envs/pytorch/lib/python3.5/site-packages/torch/functional.py", line 66, in stack return torch.cat(inputs, dim, out=out) RuntimeError: inconsistent tensor sizes at /opt/conda/conda-bld/pytorch_1512382878663/work/torch/lib/TH/generic/THTensorMath.c:2864

mathmanu commented 6 years ago

There is one problem with the new implementation in KITTI.py: the wrong channel is used for validity.

    def load_flow_from_png(png_path):
        flo_file = cv2.imread(png_path, cv2.IMREAD_UNCHANGED)
        flo_img = flo_file[:,:,1::-1].astype(np.float32)
        invalid = (flo_file[:,:,2] == 0)
        flo_img = flo_img - 32768
        flo_img = flo_img / 64
        flo_img[invalid, :] = np.NaN
        return(flo_img)

It should be:

    invalid = (flo_file[:,:,0] == 0)

I also suggest changing the implementation to the following, to be clearer:

    def load_flow_from_png(png_path):
        flo_file = cv2.imread(png_path, cv2.IMREAD_UNCHANGED)
        flo_file = flo_file[..., ::-1]                  # bgr => rgb order
        flo_img = flo_file[..., :2].astype(np.float32)  # first two channels are flow
        invalid = (flo_file[:, :, 2] == 0)              # last channel is validity
        flo_img = flo_img - 32768
        flo_img = flo_img / 64
        flo_img[invalid, :] = np.NaN
        return flo_img

Once the above is corrected and the flow is actually made sparse according to the validity mask, the default max pooling no longer works as expected (it causes problems in the mean/sum later due to NaN). However, there is no issue if I use the following function in multiscaleEPE():

    def adaptive_minmax_pool2d(input, size):
        maskp = (input >= 0).float()
        maskn = (input < 0).float()
        output = nn.functional.adaptive_max_pool2d(input * maskp, size) - \
                 nn.functional.adaptive_max_pool2d(-input * maskn, size)
        return output

    def multiscaleEPE(network_output, target_flow, weights=None, sparse=False):
        def one_scale(output, target, sparse):
            if sparse:
                target_scaled = adaptive_minmax_pool2d(target, (h, w))
            else:
                target_scaled = nn.functional.adaptive_avg_pool2d(target, (h, w))

mathmanu commented 6 years ago

My implementation of adaptive_minmax_pool2d() is not correct in the presence of NaN (arithmetic operations on NaN do not give the intended values), which explains why I am not getting a crash in mean() or sum(). For now I'll continue to indicate invalid positions with zeros.
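
(The failure mode is that masking by multiplication cannot clear NaN entries, since NaN times zero is still NaN:)

    import numpy as np

    print(np.nan * 0.0)  # nan -- so input * mask still contains NaN at invalid positions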

ClementPinard commented 6 years ago

This should all be solved now. Sorry for the careless push last Monday. I did some tests, and the network seems to learn properly now. I am currently running a FlowNetS finetuning on KITTI_noc.

mathmanu commented 6 years ago

@ClementPinard Thank you for the fix. I'll try it. And sorry if I troubled you with too many questions.

I have one comment: in KITTI.py, in the function KITTI_occ(), the flow transform flow_transforms.CenterCrop((370,1224)) is not included (it is added in KITTI_noc).

ClementPinard commented 6 years ago

Solved, thanks for your comment !