getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

TypeError in backward pass #293

Closed: Giodiro closed this issue 1 year ago

Giodiro commented 1 year ago

Hi, I'm trying out the new release to finally get falkon up to date! I'm running into a small error in the autodiff part of pykeops with the following version:

>>> pykeops.__version__
'2.1.1'

A not-so-minimal reproducing example is:

import torch
from pykeops.torch import Genred

X1 = torch.randn(20, 15).cuda().requires_grad_()
X2 = torch.randn(5, 15).cuda().requires_grad_()
v = torch.randn(5, 120).cuda().requires_grad_()
sigma = torch.randn(15).cuda().requires_grad_()

formula = 'Exp(SqDist(x1 / s, x2 / s) * IntInv(-2)) * v'
aliases = [
    'x1 = Vi(%d)' % (X1.shape[1]),
    'x2 = Vj(%d)' % (X2.shape[1]),
    'v = Vj(%d)' % (v.shape[1]),
    's = Pm(%d)' % (sigma.shape[0])
]
other_vars = [sigma]

fn = Genred(formula, aliases, reduction_op="Sum", axis=1)
out = fn(X1, X2, v, *other_vars, backend="GPU_1D")
print(out.shape)

grad = torch.autograd.grad(out.sum(), [X1])
print(grad[0].shape)

which fails at the torch.autograd.grad line with:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/torch/autograd/__init__.py", line 275, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/torch/autograd/function.py", line 253, in apply
    return user_fn(self, *args)
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/pykeops/torch/generic/generic_red.py", line 279, in backward
    grad = genconv(
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/pykeops/torch/generic/generic_red.py", line 78, in forward
    myconv = keops_binder["nvrtc" if tagCPUGPU else "cpp"](
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/keopscore/utils/Cache.py", line 68, in __call__
    obj = self.cls(*args)
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/pykeops/common/keops_io/LoadKeOps_nvrtc.py", line 15, in __init__
    super().__init__(*args, fast_init=fast_init)
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/pykeops/common/keops_io/LoadKeOps.py", line 18, in __init__
    self.init(*args)
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/pykeops/common/keops_io/LoadKeOps.py", line 126, in init
    ) = get_keops_dll(
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/keopscore/utils/Cache.py", line 27, in __call__
    self.library[str_id] = self.fun(*args)
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/keopscore/get_keops_dll.py", line 103, in get_keops_dll_impl
    chk = Chunk_Mode_Constants(red_formula)
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/keopscore/mapreduce/Chunk_Mode_Constants.py", line 38, in __init__
    self.fun_postchunk = formula.post_chunk_formula(self.nminargs)
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/keopscore/formulas/Operation.py", line 202, in post_chunk_formula
    args.append(child.post_chunk_formula(ind))
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/keopscore/formulas/Operation.py", line 202, in post_chunk_formula
    args.append(child.post_chunk_formula(ind))
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/keopscore/formulas/Operation.py", line 202, in post_chunk_formula
    args.append(child.post_chunk_formula(ind))
  [Previous line repeated 3 more times]
  File "/anaconda3/envs/torch_111/lib/python3.9/site-packages/keopscore/formulas/Chunkable_Op.py", line 56, in post_chunk_formula
    return type(self)(
TypeError: __init__() got an unexpected keyword argument 'params'

I tried changing the failing part of Chunkable_Op.py by adding a check on self.params == (), like the one in the chunked_version function of the same class, and it seems to work. I can submit a small patch for this, but since I don't really understand what the code is doing, please let me know whether just adding the check is fine!
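To illustrate the kind of failure and guard described here, a minimal sketch with hypothetical stand-in classes follows. The names Op, NoParamOp, rebuild_unguarded and rebuild_guarded are made up for illustration and are not the keopscore implementation. The sketch assumes the TypeError arises because a node is rebuilt through type(self)(...) with a params keyword that its constructor does not accept.

```python
class Op:
    """Hypothetical formula node that stores children and optional params."""

    def __init__(self, *children, params=()):
        self.children = children
        self.params = params


class NoParamOp(Op):
    """Hypothetical subclass whose constructor does not accept `params`."""

    def __init__(self, *children):
        super().__init__(*children)


def rebuild_unguarded(op):
    # Always forwards `params`; raises TypeError for NoParamOp.
    return type(op)(*op.children, params=op.params)


def rebuild_guarded(op):
    # Only forwards `params` when the node actually carries some.
    if op.params == ():
        return type(op)(*op.children)
    return type(op)(*op.children, params=op.params)


node = NoParamOp("x", "y")
try:
    rebuild_unguarded(node)
except TypeError as err:
    print("unguarded rebuild failed:", err)

print(type(rebuild_guarded(node)).__name__)
```

Under these assumptions, the guarded version rebuilds parameter-free nodes without the keyword and still forwards params for nodes that have them.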

Thanks for the great work you put into KeOps! Giacomo

de-gozaru commented 1 year ago

Hi @Giodiro @joanglaunes

Did you manage to solve this issue? I'm using the latest version, '2.1.2', but I'm still getting the same error.

Thank you for your help!

cc: @djsutherland @dvolgyes @fradav

joanglaunes commented 1 year ago

Hello @Giodiro, @de-gozaru, I am really sorry: I looked at this issue back in February but forgot to do the merge. I think it is now solved in main. Could you check?

Giodiro commented 1 year ago

Hi @joanglaunes, the fix works great, thanks a lot!