fxia22 / stn.pytorch

PyTorch version of spatial transformer networks

output blank image #11

Closed: edgarriba closed this issue 7 years ago

edgarriba commented 7 years ago

Is there any reason why STNFunctionBCHW() would return a blank image? I'm running the CPU version.

fxia22 commented 7 years ago

Mine works. Can you run this code to reproduce the problem? Thanks. https://gist.github.com/fxia22/d06234089de26d62f3e60b98babf5111

edgarriba commented 7 years ago

Oops, I forgot to mention that I'm using a custom grid. The tests seem to pass, so I'm probably doing something wrong.

fxia22 commented 7 years ago

The grid coordinates need to be in [-1,1].
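A minimal sketch of the pixel-to-[-1, 1] mapping, assuming one of the two common conventions. With align_corners=True the centres of the first and last pixels sit at -1 and +1 (this is what a min-max rescale of the indices 0..W-1 produces); with align_corners=False the outer edges of the image sit at -1 and +1. The function names are hypothetical and the flag name is borrowed from PyTorch's grid_sample; the thread does not say which convention stn.pytorch follows.

```python
import numpy as np


def pixel_to_norm(i, size, align_corners=True):
    """Map pixel index i (0 .. size - 1) to a normalized coordinate."""
    i = np.asarray(i, dtype=np.float64)
    if align_corners:
        # centres of the first and last pixel land exactly on -1 and +1
        return 2.0 * i / (size - 1) - 1.0
    # outer edges of the image land on -1 and +1; pixel centres sit half a step inside
    return (2.0 * i + 1.0) / size - 1.0


def norm_to_pixel(u, size, align_corners=True):
    """Inverse of pixel_to_norm."""
    u = np.asarray(u, dtype=np.float64)
    if align_corners:
        return (u + 1.0) * (size - 1) / 2.0
    return ((u + 1.0) * size - 1.0) / 2.0


def pixel_shift_to_norm(d, size, align_corners=True):
    """Length of a d-pixel displacement in normalized units."""
    return d * 2.0 / ((size - 1) if align_corners else size)


W = 640  # hypothetical image width
print(pixel_to_norm([0, W - 1], W))  # both ends of a row
print(pixel_shift_to_norm(150, W, align_corners=False), 150 / float(W) * 2)
```

Under align_corners=False a 150-pixel shift is exactly 150 / float(W) * 2; under align_corners=True it is 150 * 2 / (W - 1), so the two conventions differ by a fraction of a pixel.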

edgarriba commented 7 years ago

@fxia22 Thanks for the answer, and sorry for the late reply, but unfortunately normalizing to [-1, 1] doesn't solve the issue.

I'm not getting a blank image anymore; however, the routine now returns the original image unchanged. Take a look at the snippet below: I would expect it to shift the whole image 150 pixels to the right.

img = imread(...)  # NxCxHxW
H, W = img.size()[-2:]

# create a custom mapping
maps = Variable(torch.from_numpy(np.indices((H, W)).reshape(2, -1)).float())
maps = maps.floor().view(2, H, W)  # 2xHxW

# shift coordinates
maps[0] = maps[0] + 150  # x
maps[1] = maps[1] + 0    # y

# normalize maps between [-1, 1]
maps_norm = normalize(maps, -1, 1)  # tested and working !

# remap image !
out = STN('BCHW')(img, maps_norm.unsqueeze(0))

fxia22 commented 7 years ago

I can see why that doesn't work: normalizing after the shift rescales the shifted coordinates back onto [-1, 1], which cancels the shift, as if you hadn't changed anything. You can try the following snippet (I haven't tested it yet):

img = imread(...)  # NxCxHxW
H, W = img.size()[-2:]

# create a custom mapping
maps = Variable(torch.from_numpy(np.indices((H, W)).reshape(2, -1)).float())
maps = maps.floor().view(2, H, W)  # 2xHxW

# shift coordinates
maps[0] = maps[0] + 150 / float(W) * 2 # x
maps[1] = maps[1] + 0    # y

# remap image !
out = STN('BCHW')(img, maps.unsqueeze(0))

edgarriba commented 7 years ago

That gives a black image.
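Both failed attempts can be reproduced along a single axis with plain NumPy. The sketch assumes the thread's normalize() helper is a min-max rescale (its code is not shown, so minmax_normalize is a hypothetical stand-in) and uses a hypothetical width of 640. Shifting by 150 pixels and then normalizing maps the shifted range [150, 789] back onto [-1, 1], which reproduces the identity grid exactly. Skipping normalization leaves almost every coordinate far above 1, so nearly every sample is out of bounds, which is consistent with the black image reported above.

```python
import numpy as np


def minmax_normalize(a, lo=-1.0, hi=1.0):
    """Hypothetical stand-in for the thread's normalize(): rescale a onto [lo, hi]."""
    a = np.asarray(a, dtype=np.float64)
    return (a - a.min()) / (a.max() - a.min()) * (hi - lo) + lo


W = 640                                # hypothetical image width
cols = np.arange(W, dtype=np.float64)  # pixel x coordinates 0 .. W-1
identity = minmax_normalize(cols)      # the grid that returns the input unchanged

# First attempt: shift by 150 pixels, then normalize.
shifted_then_normalized = minmax_normalize(cols + 150)
print(np.allclose(shifted_then_normalized, identity))  # the shift has been rescaled away

# Second attempt: add a normalized offset to raw pixel coordinates, no normalization.
raw = cols + 150 / float(W) * 2
inside = (raw >= -1) & (raw <= 1)
print(inside.sum(), "of", W, "x coordinates fall inside [-1, 1]")  # only column 0
```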

fxia22 commented 7 years ago

Oh, sorry, it should be like this:

You need to normalize to [-1, 1] first and then apply the shift in those normalized units.

img = imread(...)  # NxCxHxW
H, W = img.size()[-2:]

# create a custom mapping
maps = Variable(torch.from_numpy(np.indices((H, W)).reshape(2, -1)).float())
maps = maps.floor().view(2, H, W)  # 2xHxW

maps_norm = normalize(maps, -1, 1)  # tested and working !

# shift coordinates
maps_norm[0] = maps_norm[0] + 150 / float(W) * 2 # x
maps_norm[1] = maps_norm[1] + 0    # y

# remap image !
out = STN('BCHW')(img, maps_norm.unsqueeze(0))
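To check the scale and sign of the offset without building the C extension, here is a minimal NumPy sketch of a backward bilinear sampler, the kind of sampler spatial transformer networks use: each output pixel p reads the input at grid(p). It is not stn.pytorch's implementation, and its conventions are assumptions: channel 0 of the 2 x H x W grid is x and channel 1 is y (matching the # x / # y comments in the thread), -1 and +1 are the centres of the edge pixels, and out-of-range samples read as 0. Two consequences are worth checking against the real module. Because sampling is backward, adding a positive offset to x moves the content left, so moving it right means subtracting. And np.indices((H, W))[0] holds row (y) indices, so whether maps[0] really is x depends on which channel the sampler reads as x.

```python
import numpy as np


def identity_grid(H, W):
    """Normalized identity grid, shape 2 x H x W: channel 0 = x, channel 1 = y."""
    ys, xs = np.meshgrid(np.linspace(-1.0, 1.0, H),
                         np.linspace(-1.0, 1.0, W), indexing="ij")
    return np.stack([xs, ys])


def bilinear_sample(img, grid):
    """Backward-warp img (C x H x W): out[:, i, j] = img sampled at grid[:, i, j].

    Out-of-range neighbours contribute 0, so samples fully outside come out black.
    """
    C, H, W = img.shape
    # convert normalized coordinates back to pixel units (edge-pixel centres at +/-1)
    x = (grid[0] + 1.0) * (W - 1) / 2.0
    y = (grid[1] + 1.0) * (H - 1) / 2.0
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    out = np.zeros((C,) + x.shape)
    for dy in (0, 1):
        for dx in (0, 1):
            xx, yy = x0 + dx, y0 + dy
            # bilinear weight of this neighbour
            w = (1.0 - np.abs(x - xx)) * (1.0 - np.abs(y - yy))
            valid = (xx >= 0) & (xx < W) & (yy >= 0) & (yy < H)
            # clip only to index safely; invalid neighbours are zeroed by `valid`
            xc = np.clip(xx, 0, W - 1)
            yc = np.clip(yy, 0, H - 1)
            out += img[:, yc, xc] * (w * valid)
    return out


def translate(img, dx, dy):
    """Move the content of img by (dx, dy) pixels; positive dx = right, positive dy = down."""
    _, H, W = img.shape
    grid = identity_grid(H, W)
    # backward sampling: output(x) = input(x - dx), so the offset is subtracted
    grid[0] -= dx * 2.0 / (W - 1)
    grid[1] -= dy * 2.0 / (H - 1)
    return bilinear_sample(img, grid)


img = np.arange(12, dtype=np.float64).reshape(1, 3, 4)
# rows shifted right by one pixel: [0, 0, 1, 2], [0, 4, 5, 6], [0, 8, 9, 10]
print(np.round(translate(img, 1, 0)[0], 6))
```

On this 1 x 3 x 4 ramp, translate(img, 1, 0) moves every row one pixel to the right and fills the vacated first column with 0; translate(img, -1, 0) moves it left instead.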
edgarriba commented 7 years ago

Oh yeah, now it works! Should the STN module somehow handle all this?

fxia22 commented 7 years ago

Adding support for letting STN handle this is not planned. I prefer not to operate directly in pixel units.
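Since the module stays in normalized units, any pixel conversion lives in the caller's code. For a pure translation, one option is to fold the pixel offset into the 2 x 3 affine matrix theta of the spatial transformer formulation, where source coordinates are theta @ (x_t, y_t, 1). The helper names below are hypothetical. The sketch assumes edge-pixel centres at -1/+1, backward sampling, and an (x, y) channel order, and it does not use stn.pytorch's own grid generator.

```python
import numpy as np


def pixel_translation_theta(dx, dy, H, W):
    """2 x 3 affine matrix (normalized units) that moves content by (dx, dy) pixels.

    Positive dx moves content right, positive dy moves it down. With backward
    sampling the source point is target - offset, hence the minus signs.
    """
    tx = -dx * 2.0 / (W - 1)
    ty = -dy * 2.0 / (H - 1)
    return np.array([[1.0, 0.0, tx],
                     [0.0, 1.0, ty]])


def affine_grid(theta, H, W):
    """Source coordinates theta @ (x_t, y_t, 1) for every target pixel, shape 2 x H x W."""
    ys, xs = np.meshgrid(np.linspace(-1.0, 1.0, H),
                         np.linspace(-1.0, 1.0, W), indexing="ij")
    target = np.stack([xs, ys, np.ones_like(xs)])   # 3 x H x W: (x_t, y_t, 1)
    return np.einsum("ij,jhw->ihw", theta, target)  # 2 x H x W: (x_s, y_s)


# a 150-pixel shift to the right on a hypothetical 480 x 640 image
theta = pixel_translation_theta(150, 0, 480, 640)
grid = affine_grid(theta, 480, 640)
```

If the real sampler reads the grid in (y, x) order instead, the rows of the resulting grid would have to be swapped before use.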

edgarriba commented 7 years ago

OK, cool! Thanks for your help.