getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

keops_tensordot failed with dtype complex64 #233

Open kuangllbnu opened 2 years ago

kuangllbnu commented 2 years ago

I wrote a demo following the example, but replacing the dtype with complex64.

import numpy as np
import torch
from pykeops.torch import LazyTensor, ComplexLazyTensor

M, N = 2, 10

x = torch.randn(M, 2, 3, 2, 2, 4, dtype=torch.complex64)
y = torch.randn(N, 2, 4, 2, 3, 2, 3, dtype=torch.complex64)
xshape, yshape = x.shape[1:], y.shape[1:]
# Flatten the trailing tensor dims: x becomes an i-variable, y a j-variable
A = ComplexLazyTensor(x.reshape(M, 1, int(np.prod(xshape))))
B = ComplexLazyTensor(y.reshape(1, N, int(np.prod(yshape))))
# Contract axes (4, 0, 2) of xshape with axes (1, 4, 2) of yshape
f_keops = A.keops_tensordot(B, xshape, yshape, (4, 0, 2), (1, 4, 2))
f_keops.sum_reduction(dim=1)  # raises the AssertionError shown below
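For comparison, the quantity this demo is after can be computed densely in plain NumPy on complex arrays. This is a sketch with random data (the helper `crandn` is hypothetical, not part of PyKeOps), and it assumes KeOps flattens each tensordot result in the same axis order as `np.tensordot`.

```python
import numpy as np

M, N = 2, 10
rng = np.random.default_rng(0)
xshape, yshape = (2, 3, 2, 2, 4), (2, 4, 2, 3, 2, 3)

def crandn(*shape):
    # Hypothetical helper: complex64 array with standard-normal real and imaginary parts
    return (rng.standard_normal(shape) + 1j * rng.standard_normal(shape)).astype(np.complex64)

x = crandn(M, *xshape)
y = crandn(N, *yshape)
axes = ((4, 0, 2), (1, 4, 2))  # same contraction axes as in the demo

# For each i, sum over j of tensordot(x[i], y[j]) contracted on the given axes
ref = np.stack([sum(np.tensordot(x[i], y[j], axes) for j in range(N)) for i in range(M)])
print(ref.shape)  # (2, 3, 2, 2, 3, 3)

# Flattened per-i output, comparable in size to a KeOps result of dimension 108
ref_flat = ref.reshape(M, -1)
print(ref_flat.shape)  # (2, 108)

# tensordot is linear in y, so this particular sum reduction equals a single
# tensordot against y summed over j (useful as a cheap cross-check)
check = np.stack([np.tensordot(x[i], y.sum(axis=0), axes) for i in range(M)])
print(np.allclose(ref, check, rtol=1e-4, atol=1e-3))
```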

But it failed with the following error:

 Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/lib/python3.8/site-packages/pykeops/common/lazy_tensor.py", line 1793, in sum_reduction
    return self.reduction("Sum", axis=axis, dim=dim, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/pykeops/common/lazy_tensor.py", line 746, in reduction
    return res()
  File "/opt/conda/lib/python3.8/site-packages/pykeops/common/lazy_tensor.py", line 2499, in __call__
    res = super().__call__(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/pykeops/common/lazy_tensor.py", line 930, in __call__
    return self.callfun(*args, *self.variables, **self.kwargs)
  File "/opt/conda/lib/python3.8/site-packages/pykeops/torch/generic/generic_red.py", line 624, in __call__
    out = GenredAutograd.apply(
  File "/opt/conda/lib/python3.8/site-packages/pykeops/torch/generic/generic_red.py", line 78, in forward
    myconv = keops_binder["nvrtc" if tagCPUGPU else "cpp"](
  File "/opt/conda/lib/python3.8/site-packages/keopscore/utils/Cache.py", line 68, in __call__
    obj = self.cls(*args)
  File "/opt/conda/lib/python3.8/site-packages/pykeops/common/keops_io/LoadKeOps_nvrtc.py", line 15, in __init__
    super().__init__(*args, fast_init=fast_init)
  File "/opt/conda/lib/python3.8/site-packages/pykeops/common/keops_io/LoadKeOps.py", line 18, in __init__
    self.init(*args)
  File "/opt/conda/lib/python3.8/site-packages/pykeops/common/keops_io/LoadKeOps.py", line 126, in init
    ) = get_keops_dll(
  File "/opt/conda/lib/python3.8/site-packages/keopscore/utils/Cache.py", line 27, in __call__
    self.library[str_id] = self.fun(*args)
  File "/opt/conda/lib/python3.8/site-packages/keopscore/get_keops_dll.py", line 93, in get_keops_dll_impl
    red_formula = GetReduction(red_formula_string, aliases)
  File "/opt/conda/lib/python3.8/site-packages/keopscore/formulas/GetReduction.py", line 27, in __new__
    reduction = eval(red_formula_string, globals(), aliases_dict)
  File "<string>", line 1, in <module>
  File "/opt/conda/lib/python3.8/site-packages/keopscore/formulas/maths/TensorDot.py", line 65, in __init__
    assert fa.dim == prod(dimsfa)
AssertionError

Is the complex64 data type not supported? Thanks.

joanglaunes commented 2 years ago

Hello @kuangllbnu,

In fact, the problem comes from using the tensordot operation with a complex dtype. Currently our tensordot operation works only for real-valued tensors: complex-valued tensors are treated as real tensors with twice the dimensions, which is why the check `fa.dim == prod(dimsfa)` fails here. Even if the parameters were compatible, the operation performed would not be a complex tensordot. The operations that are specifically implemented for complex dtypes are listed in the KeOps documentation under "Operations involving complex numbers".

So this is definitely an issue: we need to modify PyKeOps so that it either clearly forbids operations such as tensordot on complex-valued tensors, or implements their complex-valued counterparts.
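As a rough illustration of both points above, here is a NumPy sketch (not KeOps code). It first shows that a complex64 array viewed as real float32 data holds twice as many entries, which is the mismatch the assertion catches. It then shows one way a complex tensordot could be expressed with real-valued tensordots only, using the identity (ar + i·ai)(br + i·bi) = (ar·br − ai·bi) + i(ar·bi + ai·br). The function name `complex_tensordot_via_real` is hypothetical, and whether PyKeOps would implement its complex counterpart this way is an assumption.

```python
import numpy as np

def complex_tensordot_via_real(a, b, axes):
    """Hypothetical helper: complex tensordot built from four real tensordots."""
    ar, ai = a.real, a.imag
    br, bi = b.real, b.imag
    real = np.tensordot(ar, br, axes) - np.tensordot(ai, bi, axes)
    imag = np.tensordot(ar, bi, axes) + np.tensordot(ai, br, axes)
    return real + 1j * imag

rng = np.random.default_rng(0)
xshape, yshape = (2, 3, 2, 2, 4), (2, 4, 2, 3, 2, 3)
a = (rng.standard_normal(xshape) + 1j * rng.standard_normal(xshape)).astype(np.complex64)
b = (rng.standard_normal(yshape) + 1j * rng.standard_normal(yshape)).astype(np.complex64)
axes = ((4, 0, 2), (1, 4, 2))

# Viewed as raw float32 data, the interleaved (real, imag) pairs double the size:
# 2 * prod(xshape) = 192 real values instead of prod(xshape) = 96
n_real = a.view(np.float32).size
print(n_real, 2 * int(np.prod(xshape)))  # 192 192

# The four-real-tensordot construction matches NumPy's native complex tensordot
out = complex_tensordot_via_real(a, b, axes)
ref = np.tensordot(a, b, axes)
print(out.shape)  # (3, 2, 2, 3, 3)
print(np.allclose(out, ref, rtol=1e-4, atol=1e-4))
```

The same identity might serve as a workaround in PyKeOps by wrapping `x.real`, `x.imag`, `y.real` and `y.imag` (possibly after `.contiguous()`) in ordinary real `LazyTensor` objects, combining four `keops_tensordot` calls with `-` and `+`, and reducing the real and imaginary parts separately. This is an untested suggestion, not a documented PyKeOps feature.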