fxia22 / stn.pytorch

pytorch version of spatial transformer networks

Example of use in Training #5

Closed (melgor closed this issue 7 years ago)

melgor commented 7 years ago

Is there an example of STN on any dataset, maybe MNIST, like here? It is not so easy to get it working. Such an example could explain some ambiguities, like:

  1. How to implement an auxiliary loss for the grid generation?
  2. In test_conv_stn.ipynb some very strange things are happening (my current guess is sketched right after this snippet), like:
    
    conv = self.conv(input2.transpose(2,3).transpose(1,2)) # Why Transpose?
    conv = conv.transpose(1,2).transpose(2,3)  # Why Transpose Back ?
    iden = Variable(torch.cat([torch.ones(1, 328, 582, 1), torch.zeros(1, 328, 582, 3), torch.ones(1, 328, 582, 1), torch.zeros(1, 328, 582, 1)],3 )) # Why we need it?
    out = self.g(conv + iden) # Why we add this values?
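My current guess at what that snippet does, sketched below (just my reading, with hypothetical names; H=328, W=582 as in the notebook): the two transposes convert between NHWC and NCHW around the convolution, because Conv2d wants NCHW while the grid generator here seems to expect NHWC, and iden is the per-pixel identity affine matrix [1, 0, 0, 0, 1, 0], added so that a zero conv output still yields the identity transform:

    import torch
    import torch.nn as nn

    N, H, W = 1, 328, 582
    feat_nhwc = torch.randn(N, H, W, 3)                 # NHWC input image
    conv = nn.Conv2d(3, 6, kernel_size=3, padding=1)    # 6 affine params per pixel (hypothetical layer)

    x = feat_nhwc.permute(0, 3, 1, 2)                   # NHWC -> NCHW, because Conv2d expects NCHW
    theta = conv(x).permute(0, 2, 3, 1)                 # NCHW -> NHWC again for the grid generator

    # per-pixel identity affine matrix [[1, 0, 0], [0, 1, 0]] flattened to [1, 0, 0, 0, 1, 0]
    iden = torch.cat([torch.ones(N, H, W, 1), torch.zeros(N, H, W, 3),
                      torch.ones(N, H, W, 1), torch.zeros(N, H, W, 1)], 3)

    theta = theta + iden                                 # zero conv output => identity transform
    # out = self.g(theta)                                # the grid generator (self.g) would then build the sampling grid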


I'm planning to try the STN, but I'm not sure how I should start.
melgor commented 7 years ago

I'm trying to translate the Lua demo to MNIST, but I get a strange error. In the Lua example the input to the network has the "NCHW" format, like in PyTorch. But using such a network gives the wrong output: one channel of the image gets replicated and the other two are dropped.

But in the case of using the "NHWC" format, like in the example here, everything works nicely. Yet we need to have the "NCHW" format and then translate it back to "NHWC".
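To be concrete, here is a minimal sketch of the round trip I mean (assuming the sampler expects NHWC while the rest of my PyTorch model runs in NCHW; the wrapper name is just for illustration):

    import torch

    def with_nhwc(sample_nhwc, x_nchw):
        # run an NHWC-expecting sampler inside an otherwise NCHW model
        x_nhwc = x_nchw.permute(0, 2, 3, 1).contiguous()   # NCHW -> NHWC
        y_nhwc = sample_nhwc(x_nhwc)                        # e.g. lambda img: stn(img, grid)
        return y_nhwc.permute(0, 3, 1, 2).contiguous()      # NHWC -> NCHW for the next layers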

Here is my example: https://gist.github.com/melgor/23679d140cde6fc372bb6ee0ad45df5b

I know that I could change the format in the training code, but I do not know why my implementation, done like the Lua one, does not work correctly.

fxia22 commented 7 years ago

If you only need a global transformation, you don't need to use Conv STN; you only need AffineGridGen and STN. Conv STN is an implementation of this paper: https://arxiv.org/abs/1606.03558
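For a global transform, a minimal sketch looks like this (using PyTorch's stock F.affine_grid / F.grid_sample as a stand-in for AffineGridGen + STN, and a made-up localization net for MNIST-sized 1x28x28 inputs):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GlobalSTN(nn.Module):
        # a small localization net predicts one 2x3 theta per image,
        # then the whole image is warped with that single affine transform
        def __init__(self):
            super().__init__()
            self.loc = nn.Sequential(
                nn.Conv2d(1, 8, kernel_size=7), nn.MaxPool2d(2), nn.ReLU(),
                nn.Conv2d(8, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),
            )
            self.fc = nn.Linear(10 * 3 * 3, 6)
            # initialize to the identity transform, the usual STN starting point
            self.fc.weight.data.zero_()
            self.fc.bias.data.copy_(torch.tensor([1., 0., 0., 0., 1., 0.]))

        def forward(self, x):                                  # x: NCHW, e.g. 1x1x28x28
            theta = self.fc(self.loc(x).view(x.size(0), -1)).view(-1, 2, 3)
            grid = F.affine_grid(theta, x.size(), align_corners=False)
            return F.grid_sample(x, grid, align_corners=False)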