jwyang / lr-gan.pytorch

Pytorch code for our ICLR 2017 paper "Layered-Recursive GAN for image generation"

STNM using pytorch official function #3

Closed godisboy closed 7 years ago

godisboy commented 7 years ago

Hi, I implemented STNM with the official PyTorch function (grid_sample()), but the model does not converge.


class STNM(nn.Module):
    def __init__(self):
        super(STNM, self).__init__()  

    def forward(self, canvas, fgimg, fggrid, fgmask):
        #print('grid size: {} img_size: {}'.format(fggrid.size(), fgimg.size()))
        mask = F.grid_sample(fgmask, fggrid)
        fg = F.grid_sample(fgimg, fggrid)
        #torch.addcmul(tensor, value=1, tensor1, tensor2, out=None) → Tensor
        tmp1 = torch.FloatTensor(fg.size(0), fg.size(1), fg.size(2), fg.size(3))
        torch.addcmul(tmp1, mask.data, fg.data) 

        ng_mask = -1*mask 
        out = torch.add(ng_mask, 1)
        tmp2 = torch.FloatTensor(out.size(0), out.size(1), out.size(2), out.size(3))
        torch.addcmul(tmp2, out.data, canvas.data)

        return Variable(tmp1+tmp2, requires_grad = True)

jwyang commented 7 years ago

Hi, @godisboy,

It seems that you did not initialize tmp1 and tmp2 to zero before doing torch.addcmul().
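To illustrate the initialization point, here is a minimal sketch written against a current PyTorch release (the thread predates it, so treat the exact API as an assumption for the version used back then). torch.addcmul(input, tensor1, tensor2) computes input + tensor1 * tensor2, so the result only equals the plain product when input starts at zero. The sketch also shows that the functional form returns a new tensor and leaves its first argument untouched, which suggests the posted forward() discards the result. The tensor values are made up for illustration.

```python
import torch

# Toy stand-ins for the sampled mask and foreground (illustrative values only).
mask = torch.tensor([[0.0, 0.5], [1.0, 0.25]])
fg = torch.tensor([[2.0, 4.0], [6.0, 8.0]])

# Start from an explicitly zeroed buffer, not an uninitialized one.
base = torch.zeros_like(fg)

# Functional form: returns base + mask * fg as a NEW tensor.
out = torch.addcmul(base, mask, fg)

# The buffer itself is unchanged by the functional call.
base_unchanged = torch.equal(base, torch.zeros_like(fg))

# In-place form (trailing underscore) does write into the buffer.
base.addcmul_(mask, fg)
```

With a buffer created by torch.FloatTensor(...) or torch.empty(...), the starting values are arbitrary memory contents, so the same call would add the product to garbage.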

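Also, I think you should pass the Variables mask, fg and canvas themselves to torch.addcmul() rather than their .data. Otherwise the gradient cannot propagate back to canvas, fgimg, fggrid and fgmask. For the same reason, wrapping the result in a new Variable at the end cuts it off from the graph.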
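As a rough sketch of the gradient point (not the repository's own STNM implementation), the blend can be written with ordinary differentiable tensor operations so that autograd tracks every input. This assumes a current PyTorch release where Variable and Tensor are merged, a single-channel mask that broadcasts over the image channels, and an identity affine grid built with F.affine_grid for the demo; the shapes and the align_corners=False setting are illustrative choices, not taken from the thread.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class STNM(nn.Module):
    """Paste a spatially transformed foreground onto a canvas (sketch)."""

    def forward(self, canvas, fgimg, fggrid, fgmask):
        # Warp the mask and the foreground with the same sampling grid.
        mask = F.grid_sample(fgmask, fggrid, align_corners=False)
        fg = F.grid_sample(fgimg, fggrid, align_corners=False)
        # Alpha-blend; every operation here is tracked by autograd,
        # so no .data access and no re-wrapping of the result.
        return mask * fg + (1 - mask) * canvas


# Demo with hypothetical shapes: batch 2, 3 channels, 8x8 images.
torch.manual_seed(0)
n, c, h, w = 2, 3, 8, 8
canvas = torch.rand(n, c, h, w, requires_grad=True)
fgimg = torch.rand(n, c, h, w, requires_grad=True)
fgmask = torch.rand(n, 1, h, w, requires_grad=True)  # assumed 1-channel mask

# Identity affine parameters, one 2x3 matrix per batch element.
theta = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]).repeat(n, 1, 1)
theta.requires_grad_()
fggrid = F.affine_grid(theta, (n, c, h, w), align_corners=False)

out = STNM()(canvas, fgimg, fggrid, fgmask)
out.sum().backward()  # gradients reach canvas, fgimg, fgmask and theta
```

After backward(), each of canvas.grad, fgimg.grad, fgmask.grad and theta.grad is populated, which is the property the .data version loses.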

Also, did you try the original code with our STNM implementation?

godisboy commented 7 years ago

@jwyang Thanks! I see the problem now. torch.addcmul() only supports torch.FloatTensor, so I need to write the backward pass in my implementation. I will try your code soon.