getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Issues with real/complex and onehot/ifelse. #268

Open amic-github opened 1 year ago

amic-github commented 1 year ago

Dear developers of keops,

I have found some strange (unexpected?) behavior when using complex vector parameters together with one_hot and real LazyTensors.

The following code behaves as expected:

import torch
from pykeops.torch import Vi, Pm

x=torch.tensor([0.,4.,3.,2.]).reshape(4,1)
xi=Vi(x)
y=xi.one_hot(6)
weights=Pm([12., 2., 15., 1., 1., -1.])
(y*weights).sum(1)

and returns

tensor([[12.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  1.,  0.,  0.],
        [ 0.,  0., 15.,  0.,  0.,  0.]])

But if I change the parameter tensor to complex values, I do not understand the results. The following code

weightz=Pm([12.j,2.j,15.j,1.j,1.j,-1.j])
(y*weightz).sum(1)

produces the output

tensor([[ 0.+12.j,  0.+0.j,  0.+0.j,  0.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j,  0.+0.j, 12.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j,  0.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j, 12.+0.j,  0.+0.j,  0.+0.j,  0.+0.j]])

Replacing y with y.real2complex() in the last line does not change anything.
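
For comparison, here is a plain-PyTorch sketch of the result I would have expected (using torch.nn.functional.one_hot as a dense stand-in for the LazyTensor one_hot):

import torch
import torch.nn.functional as F

x = torch.tensor([0., 4., 3., 2.])
# Dense one-hot encoding, cast to complex so it can carry complex weights
onehot = F.one_hot(x.long(), num_classes=6).to(torch.complex64)
weightz = torch.tensor([12.j, 2.j, 15.j, 1.j, 1.j, -1.j])
print(onehot * weightz)  # row i should be weightz masked everywhere except position x[i]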

Another (related?) issue with one_hot appears when the input contains negative entries. The following code

x=torch.tensor([0.,8.,-3.,2.]).reshape(4,1)
xi=Vi(x)
y=xi.one_hot(6)
(y*weights).sum(1)

produces the output

tensor([[12.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  1.,  0.,  0.,  0.],
        [ 0.,  0., 15.,  0.,  0.,  0.]])

(whereas I would have expected the third row to be zero). Even more unexpectedly, replacing the last line with (y|weights).sum(1) crashes the Jupyter kernel.
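
For reference, here is a plain-PyTorch sketch of the behavior I would have expected for out-of-range indices (F.one_hot rejects negative indices outright, so mapping them to all-zero rows is my assumption):

import torch
import torch.nn.functional as F

x = torch.tensor([0., 8., -3., 2.])
idx = x.long()
valid = (idx >= 0) & (idx < 6)  # indices 8 and -3 fall outside [0, 6)
onehot = torch.zeros(4, 6)
onehot[valid] = F.one_hot(idx[valid], num_classes=6).float()
weights = torch.tensor([12., 2., 15., 1., 1., -1.])
print(onehot * weights)  # the rows for indices 8 and -3 come out all zero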

Finally, another strange behavior appears when using x.ifelse(a,b) with a real x and complex a and b. The following code

x=torch.tensor([0.,-1.,1.]).reshape(3,1)
xi=Vi(x)
y=Pm([2,-3j])
xi.ifelse(y,-y).sum(1)

produces the output

tensor([[ 2.,  0.,  0.,  0.],
        [-2.,  0.,  0.,  0.],
        [ 2.,  0.,  0.,  0.]])

which, in my opinion, does not even have the right shape, as the same code with y=Pm([2,-3]) produces, as expected, the output

tensor([[ 2., -3.],
        [-2.,  3.],
        [ 2., -3.]])

I think I may have missed how complex LazyTensors are supposed to work, but maybe it is just not supported for the moment.

antoinediez commented 1 year ago

Dear Amic and dear KeOps developers,

I tried your examples this morning. I think I understand the behaviour behind the second issue with ifelse: basically, KeOps seems to treat complex LazyTensors as real tensors with twice as many components, interleaving the real and imaginary parts of each entry. This explains the dimension issue in your example.

For instance

import torch
from pykeops.torch import LazyTensor

x = torch.tensor([[-1., 2., 4., 5.], [3., 42., -6., -7.], [5., 6., 8., 9.]])
xi = LazyTensor(x[:, None, :])

y = xi.ifelse(torch.tensor([3+2j, 7+1j]), torch.tensor([20+1j, 92]))

y.sum(1)

gives

tensor([[20.,  2.,  7.,  1.],
        [ 3.,  2., 92.,  0.],
        [ 3.,  2.,  7.,  1.]])

And this is equivalent to writing


x=torch.tensor([[-1.,2.,4.,5.],[3.,42.,-6.,-7.],[5.,6.,8.,9.]])
xi = LazyTensor(x[:,None,:])

y = xi.ifelse(torch.tensor([3.,2.,7.,1.]),torch.tensor([20.,1.,92.,0.]))

y.sum(1)
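
To see where these real quadruples come from, torch.view_as_real makes the (real, imaginary) interleaving explicit (a quick sanity check in plain PyTorch, outside KeOps):

import torch

a = torch.tensor([3+2j, 7+1j])
b = torch.tensor([20+1j, 92+0j])
print(torch.view_as_real(a).flatten())  # tensor([3., 2., 7., 1.])
print(torch.view_as_real(b).flatten())  # tensor([20., 1., 92., 0.])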

Note also that

x=torch.tensor([[-1.,2.],[3.,42.],[5.,6.]])
xi = LazyTensor(x[:,None,:])

y = xi.ifelse(torch.tensor([3+2j,7+1j]),torch.tensor([20+1j,92]))

produces the dimension error

ValueError                                Traceback (most recent call last)
Untitled-8.ipynb Cell 6 in <cell line: 4>()
      1 x=torch.tensor([[-1.,2.],[3.,42.],[5.,6.]])
      2 xi = LazyTensor(x[:,None,:])
----> 4 y = xi.ifelse(torch.tensor([3+2j,7+1j]),torch.tensor([20+1j,92]))

File /opt/homebrew/lib/python3.9/site-packages/pykeops/common/lazy_tensor.py:1413, in GenericLazyTensor.ifelse(self, other1, other2)
   1405 def ifelse(self, other1, other2):
   1406     r"""
   1407     Element-wise if-else function - a ternary operation.
   1408 
   (...)
   1411     a and b may be fixed integers or floats, or other LazyTensors.
   1412     """
-> 1413     return self.ternary(other1, other2, "IfElse", dimcheck="sameor1")

File /opt/homebrew/lib/python3.9/site-packages/pykeops/common/lazy_tensor.py:605, in GenericLazyTensor.ternary(self, other1, other2, operation, dimres, dimcheck, opt_arg)
    603 elif dimcheck == "sameor1":
    604     if not same_or_one_test(self.ndim, other1.ndim, other2.ndim):
--> 605         raise ValueError(
    606             "Operation {} expects inputs of the same dimension or dimension 1. ".format(
    607                 operation
    608             )
    609             + "Received {}, {} and {}.".format(
...
    612         )
    614 elif dimcheck != None:
    615     raise ValueError("incorrect dimcheck keyword in binary operation")

ValueError: Operation IfElse expects inputs of the same dimension or dimension 1. Received 2, 4 and 4.

which is expected, I think (basically, the 2-component complex tensors are seen as 4-dimensional real tensors, not 2-dimensional as one might believe).

However, I don't know whether this is the intended behaviour for ifelse, so could the developers confirm it? Or may I suggest adding it to the documentation?

This may also be the reason for the other issues raised by Amic.

Best, Antoine

amic-github commented 1 year ago

Dear developers,

Apparently, the problem is with Pm (when called with a list of numbers) and not with one_hot, as the following code

x=torch.tensor([0.,4.,3.,2.]).reshape(4,1)
xi=Vi(x)
y=xi.one_hot(6)
weights=torch.tensor([12., 2.+1j, 15.-1j, 1., 1., -1.])
(y*weights).sum(1)

gives the expected result. Indeed, the error does not occur when the input of Pm is already a torch tensor: the following code

buggyPm=Pm([12.j, 2.+1j, 15.-1j, 1., 1., -1.])
goodPm=Pm(torch.tensor([12.j, 2.+1j, 15.-1j, 1., 1., -1.]))
y=Vi(torch.tensor([1.]).reshape(1,1))
(y*buggyPm).sum(1),(y*goodPm).sum(1)

returns

(tensor([[ 0.+12.j, 12.+0.j, 12.+0.j,  0.+0.j, 12.+0.j,  0.+0.j]]),
 tensor([[ 0.+12.j,  2.+1.j, 15.-1.j,  1.+0.j,  1.+0.j, -1.+0.j]]))
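
In the meantime, a minimal workaround seems to be to convert the list to a torch tensor explicitly before calling Pm (the helper name below is just for illustration):

import torch
from pykeops.torch import Pm

def pm_from_list(values):
    # Pm(list) mishandles complex entries, so build the tensor
    # explicitly before handing it to KeOps.
    return Pm(torch.as_tensor(values))

weightz = pm_from_list([12.j, 2.+1j, 15.-1j, 1., 1., -1.])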

The strange behavior of one_hot with negative entries is not related to this Pm issue.

Finally, as Antoine mentioned, a.ifelse(b,c) does indeed seem to first replace b and c by real versions (with twice as many entries); the strange behavior in my example was also caused by the bad behavior of Pm.