getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Unstable autograd for logsumexp if a GPU is detected #216

Open bgalerne opened 2 years ago

bgalerne commented 2 years ago

Hi,

We recently encountered strange behavior in the autograd of logsumexp when it involves distances between high-dimensional vectors.

The following code is stable when run on Google Colab without a GPU, and unstable when run on Google Colab with a GPU. We also observed this on a private GPU server.

The problem starts to occur for patch_size >= 7 (for patch_size = 7, the patch vectors have dimension 3 * 7**2 = 147). Computing the logsumexp itself is fine, but the output of the autograd call is inconsistent (and becomes NaN for patch_size >= 8).

Do you have a solution to this problem?

!pip install pykeops[colab] > install.log
import torch
from pykeops.torch import LazyTensor

nptsx = 2
nptsy = 2
patch_size = 7  # OK for 6 with or without GPU, unstable for 7 with GPU

torch.manual_seed(0)
x = torch.rand(nptsx, 3 * patch_size**2).double()  # 3 * 7**2 = 147 features
x.requires_grad = True
y = torch.rand(nptsy, 3 * patch_size**2).double()

# Compute the matrix of squared distances between patches:
x_i = LazyTensor(x[:, None, :])
y_j = LazyTensor(y[None, :, :])
D_ij = ((x_i - y_j) ** 2).sum(2)
print(D_ij)

# Soft-min over the y patches via logsumexp, then backprop through it:
eps = 1e-2
S = (-1 / eps * D_ij).logsumexp(1)
print("S: ", S)
xgrad = torch.autograd.grad(S, x, torch.ones_like(S))[0]
print(xgrad)
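
For reference, here is the same computation done densely in plain PyTorch (no KeOps), as a sanity check; it reuses the x, y and eps defined above:

# Dense reference without KeOps, for comparison with the outputs below:
D_ref = ((x[:, None, :] - y[None, :, :]) ** 2).sum(2)           # (nptsx, nptsy)
S_ref = torch.logsumexp(-1 / eps * D_ref, dim=1, keepdim=True)  # (nptsx, 1)
xgrad_ref = torch.autograd.grad(S_ref, x, torch.ones_like(S_ref))[0]
print(xgrad_ref)  # should agree with the stable CPU run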

Best,

Bruno

bgalerne commented 2 years ago

Output with a GPU

KeOps LazyTensor
    formula: Sum(Square((Var(0,147,0) - Var(1,147,1))))
    shape: (2, 2)
[pyKeOps] Initializing build folder for dtype=float64 and lang=torch in /root/.cache/pykeops-1.5-cpython-37 ... done.
[pyKeOps] Compiling libKeOpstorchcede5e5dca in /root/.cache/pykeops-1.5-cpython-37:
       formula: Max_SumShiftExp_Reduction((Var(2,1,2) * Sum(Square((Var(0,147,0) - Var(1,147,1))))),0)
       aliases: Var(0,147,0); Var(1,147,1); Var(2,1,2); 
       dtype  : float64
... 
[pyKeOps] Compiling pybind11 template libKeOps_template_660bc304e8 in /root/.cache/pykeops-1.5-cpython-37 ... done.
Done.
S:  tensor([[-2686.1923],
        [-2400.1004]], dtype=torch.float64, grad_fn=<UnsqueezeBackward0>)
[pyKeOps] Compiling libKeOpstorch52879b9bb0 in /root/.cache/pykeops-1.5-cpython-37:
       formula: Grad_WithSavedForward(Max_SumShiftExp_Reduction((Var(2,1,2) * Sum(Square((Var(0,147,0) - Var(1,147,1))))),0), Var(0,147,0), Var(3,2,0), Var(4,2,0))
       aliases: Var(0,147,0); Var(1,147,1); Var(2,1,2); Var(3,2,0); Var(4,2,0); 
       dtype  : float64
... 
Done.
tensor([[-3.3073e+69, -1.7538e+70,  2.3040e+70, -1.2908e+68,  1.2091e+70,
         -9.1943e+69, -7.0379e+69, -9.3287e+69, -6.5520e+69, -5.2037e+68,
          7.5846e+69, -8.3046e+69,  1.7066e+70,  1.6554e+70,  1.2292e+70,
         -2.4784e+69, -1.8484e+70, -1.1473e+70,  1.5596e+70,  1.0738e+69,
          2.4893e+68, -2.5012e+69, -4.8123e+69, -1.1233e+70,  8.7421e+69,
         -4.1736e+68, -1.2503e+70,  1.6011e+70, -2.3823e+69,  5.0466e+69,
          1.6463e+70,  5.6456e+68,  1.8801e+70,  9.0410e+69,  1.7732e+70,
          6.8466e+69, -3.7261e+69, -2.6157e+70,  8.3878e+69,  2.2836e+69,
         -1.1315e+70,  1.3751e+70,  8.9140e+69,  2.9410e+70,  2.4293e+70,
          1.9628e+70, -1.3085e+70, -1.4258e+70,  1.9236e+70,  9.0734e+69,
         -1.8883e+70, -1.4469e+69,  1.3124e+69,  6.2376e+69, -1.9734e+70,
          1.3321e+70, -8.1674e+69,  2.3680e+70,  7.8060e+69,  1.3780e+70,
         -4.1810e+69,  3.5456e+68,  3.8414e+69,  2.8213e+69, -1.8976e+70,
          5.3557e+69, -1.0873e+70, -1.8532e+70, -2.1653e+70, -2.5971e+69,
          1.9953e+70,  1.9847e+70,  6.2082e+69, -2.6189e+70, -2.3619e+70,
         -1.5880e+70, -1.2907e+70,  2.3791e+70, -1.2900e+70,  2.0948e+69,
         -5.1930e+68, -2.2152e+70,  9.4493e+68,  2.7629e+70,  8.0421e+69,
          1.4308e+70,  1.6030e+70,  3.0682e+69,  8.4344e+68,  2.2563e+70,
         -5.4417e+68,  7.5142e+69, -2.9999e+69,  3.7686e+69,  9.7318e+69,
          1.7167e+70, -5.8476e+69, -6.8168e+69,  3.2915e+69,  2.2366e+68,
         -2.6016e+70, -3.7840e+67,  1.7448e+70,  2.0973e+69,  5.6880e+69,
          1.2031e+70,  2.1183e+70,  1.2865e+70, -6.3440e+69,  1.1697e+70,
          1.2451e+70, -2.0575e+70,  2.3816e+70, -2.2513e+69, -1.5969e+70,
          2.1380e+70, -4.8573e+69, -3.0703e+69,  1.0697e+70, -2.9794e+69,
         -1.6728e+70, -5.9956e+69,  3.2932e+69, -1.0133e+70,  1.6048e+70,
          2.3085e+70,  4.6225e+69, -1.9825e+69,  6.1013e+69,  1.9762e+70,
          1.0782e+70, -1.7155e+70,  7.1880e+69, -1.0524e+70, -1.7380e+70,
          7.0446e+69, -1.1225e+70,  6.9670e+69,  9.8330e+69, -8.6747e+69,
         -2.4548e+70, -1.5111e+70, -1.1121e+70, -1.0496e+70,  2.1223e+70,
         -6.5989e+69, -2.4981e+70],
        [-1.1919e+45,  3.6328e+44,  1.7767e+45, -1.3123e+43,  2.0687e+45,
         -2.4350e+44, -3.1554e+44, -6.1131e+44,  7.7442e+43,  2.5382e+44,
          1.6312e+45,  1.4885e+45,  1.1246e+45, -5.7725e+44,  1.2731e+45,
          1.1745e+45,  1.6518e+45, -7.9693e+44, -3.7115e+44, -6.5152e+44,
         -9.4615e+44,  1.3779e+45, -8.8624e+44, -6.9234e+44,  1.1666e+45,
         -9.4307e+44, -1.2734e+45,  1.5700e+44, -1.1623e+45,  6.8456e+44,
          1.7693e+45, -2.7738e+44,  8.5884e+44,  3.0046e+44, -9.5295e+44,
         -1.9365e+45,  2.3149e+44, -1.2928e+45,  1.5799e+44, -5.7478e+44,
         -3.5195e+44,  4.8461e+41, -2.4702e+43, -1.1063e+45,  1.0543e+44,
          1.6353e+45,  6.2823e+44,  9.1021e+44,  2.9798e+44,  4.5107e+44,
         -2.5790e+44,  1.3480e+45, -8.3820e+44,  9.8445e+44, -9.3030e+44,
         -1.0129e+44,  1.7962e+44,  1.8097e+45,  2.6500e+44, -6.9995e+44,
         -1.7065e+44,  1.7233e+44,  5.7458e+44,  1.0546e+45, -4.2142e+44,
          2.3501e+44,  7.8733e+44, -3.5622e+44,  2.5569e+44,  6.1064e+44,
          1.4107e+45, -1.3398e+45,  5.4180e+44,  3.9955e+44, -7.8358e+44,
         -1.0494e+45,  7.4198e+43,  1.7858e+45, -7.3460e+44,  1.2554e+45,
          5.2701e+43,  8.4714e+44, -7.1167e+44, -7.3965e+44, -1.6453e+45,
          9.8070e+44, -1.4384e+45, -7.5876e+44,  1.8241e+44,  6.7892e+44,
          2.0790e+45, -4.0796e+44, -1.3643e+45, -1.3438e+45, -9.3239e+44,
          8.2751e+43, -8.1026e+44, -1.0787e+45,  1.4275e+45, -1.5380e+45,
          1.2948e+45, -3.0020e+44, -3.8843e+44,  8.2003e+44,  8.5217e+44,
         -1.8560e+45,  4.1832e+44,  2.7776e+44,  1.6685e+45,  5.2340e+43,
         -1.7412e+45,  1.8099e+45,  4.2789e+43,  9.4817e+44, -8.1658e+44,
         -9.1409e+44, -1.2236e+45,  3.3677e+44,  4.7142e+44, -8.5721e+44,
         -1.0938e+45, -1.4361e+45,  2.8566e+44, -2.1485e+44,  1.7449e+45,
          8.9249e+44, -4.8853e+44,  1.6606e+45,  1.7216e+45, -3.5510e+43,
          1.5162e+45, -2.3412e+44, -1.1731e+45, -1.3122e+44,  1.8530e+45,
          5.8373e+44, -2.1295e+45,  3.0512e+44,  1.7763e+44,  6.2432e+43,
         -1.3862e+45, -8.3772e+44,  3.6438e+44,  6.5859e+44, -1.0412e+45,
         -1.4082e+44,  4.9443e+44]], dtype=torch.float64)
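
Note the order of magnitude of the entries: around 1e70 in the first row and 1e45 in the second, whereas the stable CPU run below returns entries of order 1e2.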
bgalerne commented 2 years ago

Output without a GPU:

[pyKeOps]: Warning, cuda was detected, but driver API could not be initialized. Switching to cpu only.
KeOps LazyTensor
    formula: Sum(Square((Var(0,147,0) - Var(1,147,1))))
    shape: (2, 2)
[pyKeOps] Initializing build folder for dtype=float64 and lang=torch in /root/.cache/pykeops-1.5-cpython-37 ... done.
[pyKeOps] Compiling libKeOpstorchcede5e5dca in /root/.cache/pykeops-1.5-cpython-37:
       formula: Max_SumShiftExp_Reduction((Var(2,1,2) * Sum(Square((Var(0,147,0) - Var(1,147,1))))),0)
       aliases: Var(0,147,0); Var(1,147,1); Var(2,1,2); 
       dtype  : float64
... 
[pyKeOps] Compiling pybind11 template libKeOps_template_660bc304e8 in /root/.cache/pykeops-1.5-cpython-37 ... done.
Done.
S:  tensor([[-2686.1923],
        [-2400.1004]], dtype=torch.float64, grad_fn=<UnsqueezeBackward0>)
[pyKeOps] Compiling libKeOpstorch52879b9bb0 in /root/.cache/pykeops-1.5-cpython-37:
       formula: Grad_WithSavedForward(Max_SumShiftExp_Reduction((Var(2,1,2) * Sum(Square((Var(0,147,0) - Var(1,147,1))))),0), Var(0,147,0), Var(3,2,0), Var(4,2,0))
       aliases: Var(0,147,0); Var(1,147,1); Var(2,1,2); Var(3,2,0); Var(4,2,0); 
       dtype  : float64
... 
Done.
tensor([[-9.6417e+01, -7.3931e+01,  1.4955e+02, -2.1054e+01,  1.2163e+02,
         -6.6818e+01,  3.1270e+01, -7.4729e+01, -8.1298e+01,  5.6472e+01,
          8.4066e+01,  1.1906e+02,  1.4606e+02,  2.2131e-01,  1.2468e+02,
          1.6701e+00,  7.8881e+00, -1.4019e+02,  3.9031e+01, -5.4641e+01,
         -7.5271e+01, -6.1466e+01, -5.7936e+01, -4.2955e+01,  6.9799e+01,
          3.3495e+00, -1.5746e+02,  1.5235e+01,  3.2102e+01,  6.9212e+01,
          1.3762e+02, -2.8897e+01,  5.3557e+01,  8.1095e+01, -2.8242e+01,
          8.2460e+00,  1.0503e+02, -1.4260e+02,  3.4877e+00, -1.1906e+02,
         -5.4790e+01, -3.1101e+01,  3.5266e+01,  8.3180e+01,  1.4922e+02,
          1.1344e+02, -7.7595e+00, -5.5417e+01,  1.3432e+01, -1.8135e+01,
         -5.0666e+01, -4.9842e+01, -1.0975e+02,  7.0419e+01, -7.7922e+01,
         -2.4914e+01, -7.2740e+01,  1.6601e+02,  1.6076e+02, -4.0970e+01,
          2.5283e+01,  9.8999e+00,  1.1843e+02,  5.4442e+01, -1.0879e+02,
          4.4199e+01,  3.7105e+01, -5.2720e+01, -2.8572e+01, -1.2479e+02,
          1.3889e+02,  8.9015e+00, -7.6581e+01, -5.7680e+01, -5.1551e+01,
         -1.3709e+02, -1.1062e+01,  1.4491e+02, -1.9149e+01,  9.2994e+01,
          6.9077e+01, -5.3799e+01, -1.4848e+02,  4.0950e+01, -6.9476e+01,
          1.0875e+02, -1.0452e+01,  4.5503e+01, -8.6905e+01,  1.5906e+02,
          6.1008e+01, -1.2599e+02, -1.5265e+02, -1.1319e+02,  9.8491e+01,
          3.9379e+01, -1.1539e+02, -5.6797e+01,  9.6901e+01,  5.3329e+00,
         -7.1917e+01, -9.0313e+01, -4.8538e+00,  3.8907e+01,  1.7749e+02,
         -6.5811e+01,  7.9949e+01,  4.0259e+01,  1.0364e+02, -1.0989e+02,
         -5.4913e+01,  3.9201e+01, -2.4151e+00, -7.1565e+01, -7.7161e+01,
          6.3263e+01,  2.5537e+00,  5.9032e+01, -9.8104e+00, -1.2108e+02,
         -9.8279e+01, -4.0080e+01,  9.9115e+01, -1.1203e+02,  1.9022e+02,
          3.6998e+01,  7.0768e+00,  1.2707e+02,  8.5244e+01,  1.6677e+02,
          7.9116e+01, -5.1711e+01, -1.2007e+00, -2.3084e+00,  1.1179e+01,
         -2.3282e+01, -9.5516e+01,  4.9620e+01,  7.7109e+01, -9.2052e+01,
         -1.7753e+02, -1.5421e+02, -6.9881e+01, -3.3981e-01,  3.9333e+01,
          6.7621e+01, -3.0115e+01],
        [-9.6518e+01,  2.9418e+01,  1.4387e+02, -1.0627e+00,  1.6752e+02,
         -1.9719e+01, -2.5552e+01, -4.9504e+01,  6.2713e+00,  2.0555e+01,
          1.3210e+02,  1.2054e+02,  9.1068e+01, -4.6745e+01,  1.0309e+02,
          9.5109e+01,  1.3377e+02, -6.4535e+01, -3.0056e+01, -5.2760e+01,
         -7.6619e+01,  1.1158e+02, -7.1767e+01, -5.6066e+01,  9.4470e+01,
         -7.6369e+01, -1.0312e+02,  1.2714e+01, -9.4121e+01,  5.5435e+01,
          1.4328e+02, -2.2462e+01,  6.9549e+01,  2.4331e+01, -7.7169e+01,
         -1.5682e+02,  1.8746e+01, -1.0469e+02,  1.2794e+01, -4.6546e+01,
         -2.8501e+01,  3.9244e-02, -2.0003e+00, -8.9587e+01,  8.5376e+00,
          1.3243e+02,  5.0874e+01,  7.3709e+01,  2.4131e+01,  3.6527e+01,
         -2.0884e+01,  1.0916e+02, -6.7877e+01,  7.9721e+01, -7.5335e+01,
         -8.2026e+00,  1.4545e+01,  1.4655e+02,  2.1459e+01, -5.6681e+01,
         -1.3819e+01,  1.3955e+01,  4.6529e+01,  8.5397e+01, -3.4127e+01,
          1.9031e+01,  6.3758e+01, -2.8847e+01,  2.0706e+01,  4.9449e+01,
          1.1424e+02, -1.0849e+02,  4.3875e+01,  3.2356e+01, -6.3455e+01,
         -8.4981e+01,  6.0085e+00,  1.4461e+02, -5.9487e+01,  1.0166e+02,
          4.2677e+00,  6.8602e+01, -5.7631e+01, -5.9897e+01, -1.3324e+02,
          7.9417e+01, -1.1648e+02, -6.1444e+01,  1.4772e+01,  5.4979e+01,
          1.6836e+02, -3.3036e+01, -1.1048e+02, -1.0882e+02, -7.5505e+01,
          6.7012e+00, -6.5615e+01, -8.7350e+01,  1.1560e+02, -1.2455e+02,
          1.0485e+02, -2.4310e+01, -3.1455e+01,  6.6406e+01,  6.9009e+01,
         -1.5030e+02,  3.3875e+01,  2.2493e+01,  1.3511e+02,  4.2385e+00,
         -1.4101e+02,  1.4656e+02,  3.4650e+00,  7.6782e+01, -6.6126e+01,
         -7.4023e+01, -9.9087e+01,  2.7272e+01,  3.8176e+01, -6.9417e+01,
         -8.8579e+01, -1.1629e+02,  2.3132e+01, -1.7399e+01,  1.4130e+02,
          7.2274e+01, -3.9561e+01,  1.3447e+02,  1.3942e+02, -2.8756e+00,
          1.2279e+02, -1.8959e+01, -9.4997e+01, -1.0626e+01,  1.5005e+02,
          4.7270e+01, -1.7245e+02,  2.4709e+01,  1.4385e+01,  5.0557e+00,
         -1.1226e+02, -6.7838e+01,  2.9507e+01,  5.3332e+01, -8.4319e+01,
         -1.1403e+01,  4.0039e+01]], dtype=torch.float64)
joanglaunes commented 2 years ago

Hello Bruno! The bug comes from the special "chunked" computation method that is implemented for high-dimensional formulas. In your case, the 7x7 color patches give dim=147 in the formula, which triggers this mode. We implemented it because the default method becomes too slow when dimensions get large, due to memory use in the CUDA kernels. I need to dig into it to see why it fails here, but a temporary fix is to disable this mode by replacing

S = (-1/eps*D_ij).logsumexp(1)

with

S = (-1/eps*D_ij).logsumexp(1, enable_chunks=False)
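
Applied to your snippet, only the reduction call changes (a minimal sketch; everything else stays as in your first post):

# Temporary workaround: fall back to the default (non-chunked) scheme.
eps = 1e-2
S = (-1 / eps * D_ij).logsumexp(1, enable_chunks=False)
print("S: ", S)
xgrad = torch.autograd.grad(S, x, torch.ones_like(S))[0]
print(xgrad)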

bgalerne commented 2 years ago

Hi Joan,

OK, thank you! We were not aware of this option.

Best,

Bruno

bgalerne commented 2 years ago

Hi Joan,

Does the commit fix the issue, and will the fix decrease performance?

Best,

Bruno

joanglaunes commented 2 years ago

Hello Bruno, sorry, I should have said something about this commit. In fact, the chunked mode I was talking about cannot be used for the gradient pass of the formula that you use; we hope to implement it in the future, but it is not done yet. So the bug was caused by the chunked mode being improperly set to True for the gradient automatically, and this is what I corrected with this commit. It means that the trick with the enable_chunks=False option is no longer needed, but the computation mode for the gradient stays the same, i.e. without chunking. Having said that, a few more remarks:

Can you tell us if the computation is now correct in your experiments? Also, what patch size do you ideally intend to use?

Best, Joan