fxia22 / stn.pytorch

pytorch version of spatial transformer networks

BCHW format #7

Open · thnkim opened this issue 7 years ago

thnkim commented 7 years ago

Excellent work!

I would like to use this in the middle of my pytorch network, so my tensors are in [Batch x Channel x Height x Width] format. I tried to use torch.permute to change their dimension order, but it was not successful. For example, when a = torch.randn((2,3,4,5)), a.stride() is (60, 20, 5, 1); but if I do b = a.permute((3,0,1,2)), b.stride() is (1, 60, 20, 5), while torch.randn(5,2,3,4).stride() is (24, 12, 4, 1).

Is there an easy and efficient way to do this? Or do I need to change the .c and .cu files in src?

I guess a.permute((3,0,1,2)).contiguous() might be a solution, but I'm not sure it is safe for Variable (autograd).

Thank you.
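
For reference, a minimal sketch of the stride behavior (assuming a recent enough PyTorch; .contiguous() is tracked by autograd, so the copy is safe to differentiate through):

import torch

a = torch.randn(2, 3, 4, 5)   # BCHW: strides (60, 20, 5, 1)
b = a.permute(0, 2, 3, 1)     # BHWC view: strides become (60, 5, 1, 20)
print(b.stride())             # memory layout is unchanged, only the strides
c = b.contiguous()            # real copy with BHWC-contiguous memory
print(c.stride())             # (60, 15, 3, 1)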

fxia22 commented 7 years ago

You can use transpose: img = img.transpose(1,2).transpose(2,3). This should change the BCHW layout to BHWC.
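
A quick sanity check of the resulting shape (a minimal sketch; the sizes are just illustrative):

import torch

img = torch.randn(2, 3, 4, 5)              # BCHW
img = img.transpose(1, 2).transpose(2, 3)  # sized BHWC now
print(img.size())                          # torch.Size([2, 4, 5, 3])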

thnkim commented 7 years ago

But transpose(1,2).transpose(2,3) does not seem to rearrange the internal array: torch.FloatTensor(1,2,3,4).stride() and torch.FloatTensor(1,2,3,4).transpose(1,2).transpose(2,3).stride() are (24, 12, 4, 1) and (24, 4, 1, 12), respectively, while torch.FloatTensor(1,3,4,2).stride() is (24, 8, 2, 1).

So if I run the code at lines 44 and 45 in my_lib.c,

real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];
real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1];

xf is not valid, because grids_strideWidth is still 1: the + 1 steps to the next element along the width dimension instead of to the x channel.

I guess it needs to be something like

real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + 1*grids_strideChannel];

although I have not tested it.
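
The stride mismatch can be reproduced from Python (a sketch; the sizes are illustrative): a grid obtained by transposing a BCHW tensor keeps stride 1 on the width dimension, so the + 1 above steps along width rather than to the x channel.

import torch

bchw = torch.randn(2, 2, 4, 5)               # B x 2 x H x W grid
grid = bchw.transpose(1, 2).transpose(2, 3)  # BHWC view
print(grid.stride())                         # (40, 5, 1, 20): width stride is 1
print(grid.contiguous().stride())            # (40, 10, 2, 1): channel stride is 1, as the C code assumes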

fxia22 commented 7 years ago

transpose(1,2).transpose(2,3) changes the internal array; you can use .size() to check. I have been using that all the time; test_conv_stn.ipynb actually uses it, FYI.

On a separate note, I guess BCHW should be the standard because it follows the PyTorch conv-layer convention. I will probably have a version for that later. Let me know what you think.

fxia22 commented 7 years ago

Oh, sorry, I misunderstood: you are talking about permutation for the grid rather than the image. Hmm, I always use the grid generator to generate the grid in BHWC format directly, so I never ran into the problem you mentioned.

thnkim commented 7 years ago

Thank you. Meanwhile I'll use permute (or transpose) and then contiguous(). It seems to work properly so far :)
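
A sketch of that workaround as a wrapper, where stn_bhwc is a hypothetical stand-in for the BHWC sampling op (not an actual name from this repo):

def stn_bchw(img, grid):
    # BCHW -> BHWC copy, so the sampler's channel-last stride assumptions hold
    img_bhwc = img.permute(0, 2, 3, 1).contiguous()
    out_bhwc = stn_bhwc(img_bhwc, grid)  # hypothetical BHWC sampling op
    # back to BCHW for the rest of the network
    return out_bhwc.permute(0, 3, 1, 2).contiguous()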

junyanz commented 7 years ago

Thanks, Fei, for the nice work. Do you have any update on BCHW support?

fxia22 commented 7 years ago

@junyanz Hi Junyan, thank you for your interest; it is likely to be added after the NIPS deadline. We do find that the majority of users need BCHW instead of BHWC and will thus prioritize it :D

junyanz commented 7 years ago

@fxia22 Thanks for your prompt response. Good luck with your NIPS submission.

fxia22 commented 7 years ago

BCHW support added. An example can be found in test.py.

junyanz commented 7 years ago

Thanks a lot!

edgarriba commented 7 years ago

@fxia22 To go from BHWC to BCHW, just use img.permute(0, 3, 1, 2).

fxia22 commented 7 years ago

@edgarriba Thanks for your suggestion. As discussed above, the problem with BCHW for the STN is that the BCHW layout is not suitable for memory coalescing. Permutation itself doesn't change the memory layout, but .contiguous() after permute will work.
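
Concretely, a minimal sketch of the BHWC-to-BCHW direction:

import torch

x = torch.randn(2, 4, 5, 3)  # BHWC
y = x.permute(0, 3, 1, 2)    # BCHW view: strides (60, 1, 15, 3)
print(y.is_contiguous())     # False -- the memory is still laid out BHWC
z = y.contiguous()           # real BCHW copy: strides (60, 20, 5, 1)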

edgarriba commented 7 years ago

Ah, right. Permute just recomputes the strides.