getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Error in Broadcast operation upgrading to v2 #246

Closed · Giodiro closed this issue 2 years ago

Giodiro commented 2 years ago

Hi,

Congratulations on releasing v2, the compilation speedups look great!

I've been working on updating falkon, and I think I can get rid of my custom branch since you've added in-place support, which made me very happy :)

I've been encountering an issue with the backward pass for an RBF kernel with a different lengthscale for each dimension. The error only occurs when d is different from t, although I don't think that should matter. Code to reproduce and the error are below:

import torch
from pykeops.torch import Genred

# m, n: number of rows of X1 / X2; d: input dimension; t: output dimension.
# The error only shows up when d != t.
m, n, d, t = 200, 200, 4, 3
X1 = torch.randn(m, d)
X2 = torch.randn(n, d)
v = torch.randn(n, t).requires_grad_()
g = torch.randn(d)  # one lengthscale per input dimension

# RBF kernel with per-dimension lengthscales, applied to v with a Sum over j.
formula = 'Exp(SqDist(x1 / g, x2 / g) * IntInv(-2)) * v'
aliases = [
    'x1 = Vi(%d)' % X1.shape[1],
    'x2 = Vj(%d)' % X2.shape[1],
    'v = Vj(%d)' % v.shape[1],
    'g = Pm(%d)' % g.shape[0],
]
conv = Genred(formula, aliases, reduction_op='Sum', axis=1)
out = conv(X1, X2, v, g, out=None)

# The backward pass w.r.t. v triggers the error.
g_v = torch.autograd.grad(out.sum(), v)

which gives the following (truncated) exception:

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/GetReduction.py in <module>

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/autodiff/Grad_WithSavedForward.py in Grad_WithSavedForward(red_formula, v, gradin, f0)
      3 
      4 def Grad_WithSavedForward(red_formula, v, gradin, f0):
----> 5     return red_formula.DiffT(v, gradin, f0)

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/reductions/Sum_Reduction.py in DiffT(self, v, gradin, f0)
     43         from keopscore.formulas.autodiff import Grad
     44 
---> 45         return Sum_Reduction(Grad(self.formula, v, gradin), v.cat % 2)

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/autodiff/Grad.py in Grad(formula, v, gradin)
     21         cat = 1 - v.cat
     22         gradin = Var(ind, dim, cat)
---> 23     return formula.DiffT(v, gradin)

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/maths/Mult.py in DiffT(self, v, gradin)
     26         fa, fb = self.children
     27         if fa.dim == 1 and fb.dim > 1:
---> 28             return fa.DiffT(v, Scalprod(gradin, fb)) + fb.DiffT(v, fa * gradin)
     29         elif fb.dim == 1 and fa.dim > 1:
     30             return fa.DiffT(v, fb * gradin) + fb.DiffT(v, Scalprod(gradin, fa))

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/VectorizedScalarOp.py in DiffT(self, v, gradin)
     35         if len(self.children) == 1:
     36             derivatives = (derivatives,)
---> 37         return sum(f.DiffT(v, gradin * df) for f, df in zip(self.children, derivatives))
     38 
     39     def Derivative(self):

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/VectorizedScalarOp.py in <genexpr>(.0)
     35         if len(self.children) == 1:
     36             derivatives = (derivatives,)
---> 37         return sum(f.DiffT(v, gradin * df) for f, df in zip(self.children, derivatives))
     38 
     39     def Derivative(self):

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/maths/Mult.py in DiffT(self, v, gradin)
     30             return fa.DiffT(v, fb * gradin) + fb.DiffT(v, Scalprod(gradin, fa))
     31         else:
---> 32             return fa.DiffT(v, fb * gradin) + fb.DiffT(v, fa * gradin)
     33 
     34     # parameters for testing the operation (optional)

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/maths/Scalprod.py in DiffT(self, v, gradin)
     34     def DiffT(self, v, gradin):
     35         fa, fb = self.children
---> 36         return gradin * (fa.DiffT(v, fb) + fb.DiffT(v, fa))
     37 
     38     def initacc_chunk(self, acc):

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/maths/Subtract.py in DiffT(self, v, gradin)
     23     def DiffT(self, v, gradin):
     24         fa, fb = self.children
---> 25         return fa.DiffT(v, gradin) - fb.DiffT(v, gradin)
     26 
     27     # parameters for testing the operation (optional)

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/maths/Divide.py in DiffT(self, v, gradin)
     33             ) / Square(fb)
     34         else:
---> 35             return (fa.DiffT(v, fb * gradin) - fb.DiffT(v, fa * gradin)) / Square(fb)
     36 
     37     # parameters for testing the operation (optional)

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/Operation.py in __truediv__(self, other)
    112         from keopscore.formulas.maths.Divide import Divide
    113 
--> 114         return Divide(self, int2Op(other))
    115 
    116     def __rtruediv__(self, other):

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/maths/Divide.py in Divide(arg0, arg1)
     43 def Divide(arg0, arg1):
     44     if isinstance(arg0, Zero):
---> 45         return Broadcast(arg0, arg1.dim)
     46     elif isinstance(arg1, Zero):
     47         KeOps_Error("division by zero")

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/Operation.py in Broadcast(arg, dim)
    231         return SumT(arg, dim)
    232     else:
--> 233         KeOps_Error("dimensions are not compatible for Broadcast operation")

~/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/utils/misc_utils.py in KeOps_Error(message, show_line_number)
     26         frameinfo = getframeinfo(currentframe().f_back)
     27         message += f" (error at line {frameinfo.lineno} in file {frameinfo.filename})"
---> 28     raise ValueError(message)
     29 
     30 

ValueError: [KeOps] Error : dimensions are not compatible for Broadcast operation (error at line 233 in file /home/giacomo/miniconda3/envs/torch/lib/python3.9/site-packages/keopscore/formulas/Operation.py)
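For reference, here is what the formula computes, and the gradient with respect to v that the backward pass should return, written in plain NumPy (a minimal sketch, independent of KeOps; variable names mirror the repro above, with small sizes):

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, d, t = 5, 4, 4, 3  # small sizes; d != t as in the repro
X1 = rng.standard_normal((m, d))
X2 = rng.standard_normal((n, d))
v = rng.standard_normal((n, t))
g = rng.standard_normal(d)

# K[i, j] = exp(-||x1_i/g - x2_j/g||^2 / 2), with elementwise division by g.
diff = X1[:, None, :] / g - X2[None, :, :] / g   # (m, n, d)
K = np.exp(-0.5 * (diff ** 2).sum(axis=-1))      # (m, n)
out = K @ v                                      # (m, t): the Sum reduction over j

# Gradient of out.sum() w.r.t. v: v[j, k] contributes K[:, j].sum() once per k,
# i.e. grad_v = K.T @ ones((m, t)).
grad_v = np.repeat(K.sum(axis=0)[:, None], t, axis=1)  # (n, t)
assert np.allclose(grad_v, K.T @ np.ones((m, t)))
```

This is the value the `torch.autograd.grad` call in the repro should produce once the backward pass succeeds.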

Thanks again for your work!

joanglaunes commented 2 years ago

Hello @Giodiro , sorry for the delay in answering. I have just fixed the issue in the main branch; it was an error in the formula for the gradient of the divide operation, in the case of vector variables. If you can use the main branch in your framework, please let us know whether it works on your side as well.
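For anyone hitting this on older versions: the quotient rule that the gradient of a division has to implement can be sanity-checked numerically. A small NumPy sketch (not the KeOps code itself) comparing the analytic gradients of y = a / b against finite differences:

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.standard_normal(4)
b = rng.standard_normal(4) + 3.0   # keep the denominator away from zero
grad = rng.standard_normal(4)      # incoming gradient on y = a / b

# Quotient rule: dL/da = grad / b, dL/db = -a * grad / b**2.
dL_da = grad / b
dL_db = -a * grad / b ** 2

# Central finite-difference check of both gradients.
eps = 1e-6
def L(a_, b_):
    return np.sum(grad * (a_ / b_))

fd_da = np.array([(L(a + eps * e, b) - L(a - eps * e, b)) / (2 * eps)
                  for e in np.eye(4)])
fd_db = np.array([(L(a, b + eps * e) - L(a, b - eps * e)) / (2 * eps)
                  for e in np.eye(4)])
assert np.allclose(dL_da, fd_da, atol=1e-5)
assert np.allclose(dL_db, fd_db, atol=1e-5)
```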

Giodiro commented 2 years ago

Thanks @joanglaunes , I'm testing it now and it seems to be working perfectly! Thanks a lot, Giacomo