creotiv / hdrnet-pytorch

Unofficial PyTorch implementation of 'Deep Bilateral Learning for Real-Time Image Enhancement', SIGGRAPH 2017 https://groups.csail.mit.edu/graphics/hdrnet/

Discuss about the grid_sample #5

Open qimw opened 4 years ago

qimw commented 4 years ago

The current code implements the Slice operation through F.grid_sample(), but I found that the coordinate order may be wrong.

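import numpy as np
import torch
import torch.nn.functional as F

# Fill a 10x10x10 volume so that every voxel stores its index along the last axis.
box = np.ones((10, 10, 10))
for i in range(10):
    box[:, :, i] = i
# Shape (N, C, D, H, W) = (1, 5, 10, 10, 10).
box = torch.from_numpy(box).view(1, 1, 10, 10, 10).repeat(1, 5, 1, 1, 1)
# A single sample point; only the first grid coordinate is non-zero.
indices = np.array([0.5, 0, 0]).reshape(1, 1, 1, 1, 3)
indices[0, 0, 0, 0, 0] = 0.5  # vary this entry to see which axis of box it addresses
indices = torch.from_numpy(indices)
print(box.shape)
print(indices.shape)
# box varies only along its last axis, so the output changes only with indices[..., 0].
print(F.grid_sample(box, indices))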

Let's take box to be the bilateral_grid and indices to be the guidemap. This code shows that indices[0,0,0,0,0], indices[0,0,0,0,1], and indices[0,0,0,0,2] correspond to the fifth, fourth, and third dimensions of box respectively. So the code should be guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1) # Nx1xHxWx3 instead of guidemap_guide = torch.cat([guidemap, hg, wg], dim=3).unsqueeze(1) # Nx1xHxWx3
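
To check the axis mapping for all three grid entries at once, here is a minimal sketch. The helper name sample and the index-encoding volume are made up for illustration, and it assumes align_corners=True with nearest-neighbour sampling so that -1 and 1 land exactly on the first and last voxel.

```python
import torch
import torch.nn.functional as F

D, H, W = 4, 5, 6
# Each voxel stores 100*d + 10*h + w, so a sampled value reveals which voxel was hit.
volume = (
    100 * torch.arange(D).view(D, 1, 1)
    + 10 * torch.arange(H).view(1, H, 1)
    + torch.arange(W).view(1, 1, W)
).float().view(1, 1, D, H, W)  # (N, C, D, H, W)

def sample(x, y, z):
    # Hypothetical helper: sample one point given as (x, y, z) in [-1, 1].
    grid = torch.tensor([x, y, z], dtype=torch.float32).view(1, 1, 1, 1, 3)
    return F.grid_sample(volume, grid, mode="nearest", align_corners=True).item()

corner = sample(-1.0, -1.0, -1.0)   # voxel (d=0, h=0, w=0)
moved_x = sample(1.0, -1.0, -1.0)   # only the w index changes
moved_y = sample(-1.0, 1.0, -1.0)   # only the h index changes
moved_z = sample(-1.0, -1.0, 1.0)   # only the d index changes
print(corner, moved_x, moved_y, moved_z)
```

If the first grid entry moves the w index, the second the h index, and the third the d index, then the last dimension of the grid is ordered (x, y, z) = (width, height, depth), which matches putting wg first and guidemap last.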

I also found that someone has worked out that F.grid_sample() is in fact doing trilinear interpolation when the input is 5D. See the PyTorch issue you discussed before.

creotiv commented 4 years ago

As for correctness, it's hard to say now because I don't remember. Also, grid_sample has very poor documentation.

As for interpolation, no. It's not doing trilinear; it's doing bilinear, but for 3 coordinates, by interpolating 3 pairs: (x,y), (y,z), (x,z).

creotiv commented 4 years ago

I will try to look into this bug over the weekend.

qimw commented 4 years ago

Sorry to bother you, but I don't understand why it isn't trilinear. The code at the URL above generates the interpolated value from the eight surrounding values. Am I misunderstanding trilinear interpolation?

qimw commented 4 years ago

I tried training the network on Cityscapes and Foggy Cityscapes to remove the fog. The results appear to be much better after the modification. :) Before: out_before (image). After: out_fix (image).

creotiv commented 4 years ago

I just read the docs, and yes, you were right: "In the case of 5D inputs, grid[n, d, h, w] specifies the x, y, z", and the guide map is our Z vector.

creotiv commented 4 years ago

Updated master. thanks for finding it)

qimw commented 4 years ago

yes! And we still need to make the trilinear clear!

creotiv commented 4 years ago

I opened an issue about that, but it seems nobody cares, and I don't know CUDA well enough to write it myself :(

XiaotianM commented 4 years ago

I have the same question. Why is F.grid_sample() said to do "bilinear" rather than trilinear interpolation when the input is 5D? The following code looks trilinear: https://github.com/pytorch/pytorch/blob/81bf73643b6552a63794fb889238aaf0c2a7baa6/aten/src/ATen/native/GridSampler.cpp#L68-L147

xuqingyu26 commented 4 years ago

yes! And we still need to make the trilinear clear!

Hello, have you figured it out?

creotiv commented 3 years ago

So I think I have finally fixed it. The problem was the poor documentation of the grid_sample function and a few small bugs in the code. I've tested on one image and the network learns how to convert it. I hope to show examples trained on my dataset soon.

yaoyuan13 commented 3 years ago

So I think I have finally fixed it. The problem was the poor documentation of the grid_sample function and a few small bugs in the code. I've tested on one image and the network learns how to convert it. I hope to show examples trained on my dataset soon.

So is it right that F.grid_sample() does trilinear interpolation when the input is 5D? Are there any issues to be aware of when using it?

creotiv commented 3 years ago

Yep. The problem is that the configuration says bi-linear, but in reality it is tri-linear.
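
For anyone who wants to verify this numerically, the following is a rough sketch that blends the eight surrounding voxels by hand and compares the result with F.grid_sample in "bilinear" mode on a 5D input. The function manual_trilinear is a hypothetical helper written for this check, not part of PyTorch or of this repo, and it assumes align_corners=True and a sample point inside the volume.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
D, H, W = 4, 5, 6
volume = torch.rand(1, 1, D, H, W, dtype=torch.float64)

def manual_trilinear(vol, x, y, z):
    # Hypothetical reference: blend the 8 voxels around (x, y, z) given in [-1, 1].
    _, _, d, h, w = vol.shape
    # Map normalized coordinates to continuous voxel indices (align_corners=True convention).
    fx, fy, fz = (x + 1) / 2 * (w - 1), (y + 1) / 2 * (h - 1), (z + 1) / 2 * (d - 1)
    x0, y0, z0 = math.floor(fx), math.floor(fy), math.floor(fz)
    x1, y1, z1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1), min(z0 + 1, d - 1)
    tx, ty, tz = fx - x0, fy - y0, fz - z0
    out = 0.0
    for zi, wz in ((z0, 1 - tz), (z1, tz)):
        for yi, wy in ((y0, 1 - ty), (y1, ty)):
            for xi, wx in ((x0, 1 - tx), (x1, tx)):
                out += wz * wy * wx * vol[0, 0, zi, yi, xi].item()
    return out

x, y, z = 0.3, -0.4, 0.1
grid = torch.tensor([x, y, z], dtype=torch.float64).view(1, 1, 1, 1, 3)
reference = F.grid_sample(volume, grid, mode="bilinear", align_corners=True).item()
manual = manual_trilinear(volume, x, y, z)
max_abs_diff = abs(reference - manual)
print(reference, manual, max_abs_diff)
```

If the two numbers agree to within floating-point error, the "bilinear" mode is behaving as a trilinear interpolation over the eight neighbours for 5D inputs.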

onpix commented 3 years ago

Hey guys, I have a question about the order of these three dimensions (wg, hg, guidemap). Why is your code written like this:

guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1)  # Nx1xHxWx3

instead of this one:

guidemap_guide = torch.cat([hg, wg, guidemap], dim=3).unsqueeze(1)  # Nx1xHxWx3

creotiv commented 3 years ago

I don't really remember, but you are right, it should be h,w,3 shape

On Wed, Apr 21, 2021, 6:05 AM HaoyuanWang @.***> wrote:

Hey guys, I have a problem with the order of these three dimensions(wg, hg, guidemap). Why your code is as this:

guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1) # Nx1xHxWx3

instead of this one:

guidemap_guide = torch.cat([hg, wg, guidemap], dim=3).unsqueeze(1) # Nx1xHxWx3

puneetmatharu commented 3 years ago

I think it was correct in the original implementation; the triplets in the trailing dimension of guidemap_guide contain the (x, y, z) coordinates of a point in the bilateral grid. The tensor wg contains the x coordinate, so it should go first. However, the order won't matter in the current implementation because the spatial resolution is the same in the x and y directions, so they can be permuted freely.
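
A small sketch of what the ordering does in isolation, using a deliberately non-square image and a toy grid whose value is just its own normalized x position. The names slice_grid, GD, GH, and GW are made up for this illustration, the guide map is fixed at zero, and align_corners=True is assumed so the ramp is reproduced exactly.

```python
import torch
import torch.nn.functional as F

N, C, GD, GH, GW = 1, 1, 8, 16, 16
H, W = 6, 9  # non-square image so rows and columns are distinguishable

# Toy bilateral grid: value equals the normalized x position, constant along y and z.
ramp = torch.linspace(-1.0, 1.0, GW)
bilateral_grid = ramp.view(1, 1, 1, 1, GW).expand(N, C, GD, GH, GW).contiguous()

# Per-pixel coordinates in [-1, 1]: hg varies along rows, wg along columns.
hg = torch.linspace(-1.0, 1.0, H).view(1, H, 1, 1).expand(N, H, W, 1)
wg = torch.linspace(-1.0, 1.0, W).view(1, 1, W, 1).expand(N, H, W, 1)
guidemap = torch.zeros(N, H, W, 1)

def slice_grid(first, second):
    # Hypothetical helper: build the N x 1 x H x W x 3 sampling grid and slice.
    coords = torch.cat([first, second, guidemap], dim=3).unsqueeze(1)
    out = F.grid_sample(bilateral_grid, coords, mode="bilinear", align_corners=True)
    return out[0, 0, 0]  # H x W

x_first = slice_grid(wg, hg)  # output follows the image columns
y_first = slice_grid(hg, wg)  # output follows the image rows instead
print(x_first[0], y_first[:, 0])
```

With this toy grid, [wg, hg, guidemap] makes the grid's x axis follow the image width, while [hg, wg, guidemap] makes it follow the image height, so the two orderings address the grid in transposed ways.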

QiuJueqin commented 2 years ago

I think it was correct in the original implementation; the triplets in the trailing dimension of guidemap_guide contain the (x, y, z) coordinates of a point in the bilateral grid. The tensor wg contains the x coordinate, so it should go first. However, the order won't matter in the current implementation because the spatial resolution is the same in the x and y directions, so they can be permuted freely.

Agree. The current implementation in master branch is wrong:

https://github.com/creotiv/hdrnet-pytorch/blob/bfe1c8f706d1617275c8fd00389212c7fbcd93a6/model.py#L53

Instead, this line should be

guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1) # Nx1xHxWx3

creotiv commented 2 years ago

I've added bilateral_slice from the original repo, compiled for JIT. But it still has some problems with optimization for some reason. So I think grid_sample was working correctly.

QiuJueqin commented 2 years ago

After some comparison with my customized tri-linear interpolation, which consists of multiple 2D bilinear interpolations, I'm now pretty sure that the second argument to F.grid_sample (grid) should be something like

torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1)

instead of

torch.cat([hg, wg, guidemap], dim=3).unsqueeze(1)

Furthermore, elements in grid along all axes should be in the [-1, 1] range, not [0, 1], which means that in the guidance net the activation should be torch.tanh instead of torch.sigmoid.

The result of my customized slicing operator is very similar to that of F.grid_sample with inputs formatted as described above. The absolute error is smaller than 1E-5:

all close with atol=1E-6:  False
all close with atol=1E-5:  True

creotiv commented 2 years ago

@QiuJueqin I've tested both ways, and also with the new bilateral_slice JIT op. They work similarly; I mean the net still has some problem converging. I need to test the two solutions side by side to find the problem, because right now everything looks similar.

As for the activation, it should be in [-1,1], but that doesn't mean [0,1] will not work, because [0,1] is within [-1,1].

QiuJueqin commented 2 years ago

@creotiv Yes, [0,1] works as well, but half of the valid range in the depth dimension in the bilateral grid is "wasted", which intuitively reduces the model capacity.
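
To make the "wasted" half concrete, here is a small sketch that feeds a range of guide values through grid_sample and reports which depth positions they reach. The toy grid stores its own depth-bin index, the helper depth_reached and the range of raw values are made up for this illustration, and align_corners=True is assumed.

```python
import torch
import torch.nn.functional as F

GD = 8
# Toy grid whose value equals the depth-bin index, constant over x and y.
depth_ramp = torch.arange(GD, dtype=torch.float32)
grid = depth_ramp.view(1, 1, GD, 1, 1).expand(1, 1, GD, 2, 2).contiguous()

def depth_reached(guide_values):
    # Hypothetical helper: sample at x = y = 0 and z = guide value, return the depth hit.
    n = guide_values.numel()
    zeros = torch.zeros(n)
    coords = torch.stack([zeros, zeros, guide_values], dim=1).view(1, 1, 1, n, 3)
    return F.grid_sample(grid, coords, mode="bilinear", align_corners=True).view(-1)

raw = torch.linspace(-4.0, 4.0, 101)  # made-up pre-activation guide values
sigmoid_depth = depth_reached(torch.sigmoid(raw))
tanh_depth = depth_reached(torch.tanh(raw))
rescaled_depth = depth_reached(2 * torch.sigmoid(raw) - 1)

print("sigmoid:", sigmoid_depth.min().item(), sigmoid_depth.max().item())
print("tanh:", tanh_depth.min().item(), tanh_depth.max().item())
print("2*sigmoid-1:", rescaled_depth.min().item(), rescaled_depth.max().item())
```

In this sketch the sigmoid guide never reaches the lower half of the depth bins, while tanh, or a sigmoid rescaled with 2*g - 1, spans nearly the whole range. Rescaling is one possible alternative to swapping the activation.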