getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

backward contiguous error during permute and cat operation #149

Closed hbgtjxzbbx closed 2 years ago

hbgtjxzbbx commented 3 years ago

Hello :) Thanks for the great library! I ran into contiguity errors during the backward call; any suggestion would be super helpful, thanks! One error is caused by a "permute" operation after a kernel reduction, and the other is caused by "cat"-ing two kernel reduction results.

Both cases produce the same output:

RuntimeError: [Keops] Arg at position 3: is not contiguous. Please provide 'contiguous' data array, 
as KeOps does not support strides. If you're getting this error in the 'backward' pass of a code using 
torch.sum() on the output of a KeOps routine, you should consider replacing 'a.sum()' 
with 'torch.dot(a.view(-1), torch.ones_like(a).view(-1))'. 

Here is the example code:

from pykeops.torch import LazyTensor
import torch

a = torch.rand(2, 1000, 5)
a.requires_grad = True
b = torch.rand(2, 1000, 5)
c1 = torch.rand(2, 1000, 5)
c2 = torch.rand(2, 1000, 5)

a_i = LazyTensor(a[:, :, None])
b_j = LazyTensor(b[:, None])

dist = a_i.sqdist(b_j)
kernel = dist.exp()
d1 = kernel @ c1
d2 = kernel @ c2

# Failing case 1
d_permute = d1.permute(0, 2, 1)
d_permute.contiguous().clone().mean().backward()

# Failing case 2
d_cat = torch.cat([d1, d2], 2)
d_cat.mean().backward()
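Independently of KeOps, the root cause can be observed with plain PyTorch: a permuted view carries transposed strides, and the gradient that flows back through .permute() does too, so it reaches the KeOps backward routine as a non-contiguous tensor. A minimal illustration (this snippet is an editorial addition, not from the thread):

```python
import torch

# A permuted view keeps the same storage but has transposed strides,
# so it is non-contiguous; gradients flowing back through .permute()
# have the same property.
x = torch.rand(2, 1000, 5)
x_t = x.permute(0, 2, 1)
print(x.is_contiguous())    # True
print(x_t.is_contiguous())  # False
```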
jeanfeydy commented 3 years ago

Hi @hbgtjxzbbx ,

Thanks for your important bug report! As of today, a simple workaround is to proceed as follows:

from pykeops.torch import LazyTensor
import torch

a = torch.rand(2, 1000, 5)
a.requires_grad = True
b = torch.rand(2, 1000, 5)
c1 = torch.rand(2, 1000, 5)
c2 = torch.rand(2, 1000, 5)

a_i = LazyTensor(a[:, :, None])
b_j = LazyTensor(b[:, None])

dist = a_i.sqdist(b_j)
kernel = dist.exp()


class ContiguousBackward(torch.autograd.Function):
    """Identity in the forward pass; casts the incoming gradient
    to a contiguous tensor in the backward pass."""

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.contiguous()


# Test case 1
d1 = ContiguousBackward.apply(kernel @ c1)
d1.permute(0, 2, 1).mean().backward()

# Test case 2
d1 = ContiguousBackward.apply(kernel @ c1)
d2 = ContiguousBackward.apply(kernel @ c2)
d_cat = torch.cat([d1, d2], 2)
d_cat.mean().backward()

The ContiguousBackward wrapper ensures that all tensors that are passed to the backward KeOps function are indeed contiguous. Going forward, I believe that we could adopt this behavior by default... Back in 2019, we opted against silently casting the input to KeOps routines as contiguous tensors. The idea was to make users aware of the (small) cost of these transposition operations, thus allowing them to best profile their programs. With hindsight, however, we were probably over-thinking the problem: the overhead of the .contiguous() operation is negligible in most (all?) settings. What do you think @bcharlier, @joanglaunes ?

Best regards, And see you soon! Jean
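The wrapper can be checked in isolation, without KeOps: a plain tensor stands in for the kernel reduction, and the gradient reaching the leaf is guaranteed contiguous (an editorial sketch, not part of the original thread):

```python
import torch

class ContiguousBackward(torch.autograd.Function):
    # Identity forward; forces the incoming gradient to be contiguous.
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.contiguous()

x = torch.rand(2, 1000, 5, requires_grad=True)
y = ContiguousBackward.apply(x)
y.permute(0, 2, 1).mean().backward()
print(x.grad.is_contiguous())  # True
```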

hbgtjxzbbx commented 3 years ago

Jean, Thanks for fixing this :)

bcharlier commented 3 years ago

Hi all,

To be more specific:

jeanfeydy commented 3 years ago

Hi @bcharlier ,

Thanks for the details :-) Would you be fine with just casting all tensors as contiguous prior to the generic reductions? If tensors are already contiguous, this would do nothing; and otherwise, in my experience, the cost of the "transpose" is always negligible anyway. Do you see a situation where this could be a problem?

See you soon, Jean

bcharlier commented 3 years ago

In fact, you will pay for a copy (not a simple view change) of the non-contiguous variable. Usually this is not a major problem, but when M or N are large it can be an issue (in terms of memory footprint, at least).

That being said, if 99.9% of users apply the .contiguous() method by hand to fix this, we may as well do it automatically :)
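The cost in question is easy to see in plain PyTorch: .contiguous() is a no-op (returning the same tensor) when the input is already contiguous, but pays for a full copy into fresh storage otherwise. A small editorial sketch, not from the thread:

```python
import torch

a = torch.rand(1000, 5)
print(a.contiguous().data_ptr() == a.data_ptr())  # True: no copy made

t = a.t()          # non-contiguous view, shares storage with a
c = t.contiguous() # pays for a full copy into new storage
print(c.data_ptr() == a.data_ptr())  # False: new memory was allocated
print(torch.equal(c, t))             # True: same values either way
```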

joanglaunes commented 3 years ago

Hello @hbgtjxzbbx , We decided to apply a .contiguous() cast to all non-contiguous input tensors in the PyTorch bindings: with a warning in the forward pass (since it triggers an extra copy that the user may want to avoid), and silently in the backward pass. The update has been merged into master, so I am closing this issue now. Feel free to re-open if needed.
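The merged behavior can be sketched as a small helper (the function name and message below are illustrative placeholders, not the actual KeOps code): warn and copy in the forward pass, copy silently in the backward pass.

```python
import warnings
import torch

def ensure_contiguous(x, warn=True):
    # Hypothetical helper mirroring the merged behavior: non-contiguous
    # inputs are cast with .contiguous(); the forward pass warns about
    # the extra copy, while the backward pass does it silently (warn=False).
    if not x.is_contiguous():
        if warn:
            warnings.warn("[KeOps] input tensor is not contiguous: "
                          "an extra copy will be performed.")
        return x.contiguous()
    return x

x = torch.rand(5, 3).t()  # non-contiguous view
y = ensure_contiguous(x, warn=False)
print(y.is_contiguous())  # True
```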