getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

( lazyvec1 - lazyvec1.t() ) @ lazyvec2 fails with: 'NoneType' object has no attribute 'replace' #211

Open joanglaunes opened 2 years ago

joanglaunes commented 2 years ago

Discussed in https://github.com/getkeops/keops/discussions/210

Originally posted by **tvercaut** February 3, 2022

Hi,

Thanks for the nice work. I am not sure I am using the library correctly, as it is the first time I have tried it, but the following behaviour looks counter-intuitive to me. Can someone tell me if it's a bug or a misunderstanding on my part?

For two vectors `u` and `v` of size (Nx1), I would like to compute `(u_l - u_l.T) @ v` in a lazy fashion. Doing this directly, with `u_l` being the lazy vector corresponding to the concrete vector `u` and `v` being a concrete vector, fails with the following error:

```
'NoneType' object has no attribute 'replace'
```

The workaround I found was to create a second lazy vector from the transposed concrete vector `u.t()`. However, this was also counter-intuitive, since `LazyTensor(u.t(),axis=1)` is reported to be of size 1x1xN rather than 1xN. The second workaround I found was to create the lazy vector from a 3D tensor of size 1xNx1.

Let me know what I might be missing. Here is an example (also on [colab](https://colab.research.google.com/drive/15WAK3HtSQuriW2ohFuMSGFOEFTvPdEBr?usp=sharing)):

```python
import torch
print(torch.__version__)

torchdevice = torch.device('cpu')
if torch.cuda.is_available():
    torchdevice = torch.device('cuda')
    print('Default GPU is ' + torch.cuda.get_device_name(torch.device('cuda')))
print('Running on ' + str(torchdevice))

!pip install pykeops[colab]
import pykeops
print('pykeops version:', pykeops.__version__)
from pykeops.torch import LazyTensor

# Generate two vectors u and v
u = torch.normal(0, 1, size=(5, 1), device=torchdevice)
v = torch.normal(0, 1, size=(5,), device=torchdevice)

# Compute (u - u.T) @ v in a dense fashion
P = u - u.t()
Pv = P @ v
print('Pv: ', Pv)

# Compute (u - u.T) @ v in a lazy fashion
u_l = LazyTensor(u, axis=0)
print('u_l: ', u_l)

# Using the transpose of the lazy tensor doesn't work here
#print('u_l.t(): ', u_l.t())
#P_l = u_l - u_l.t()
#print('P_l:', P_l)

# We thus go for transposing the concrete vector and creating a new lazy tensor
# However, it seems to fail when provided as a 2D matrix of size 1 x N
#ut_l = LazyTensor(u.t(), axis=1)
print('LazyTensor(u.t(),axis=1): ', LazyTensor(u.t(), axis=1))
ut_l = LazyTensor(u.t()[..., None])
print('ut_l: ', ut_l)
P_l = u_l - ut_l
print('P_l:', P_l)
Pv_froml = P_l @ v
print('Pv_froml: ', Pv_froml)

# Putting this at the end just to highlight that this fails
( u_l - u_l.t() ) @ v
```
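As a sanity check that does not depend on KeOps, each entry of the dense product has a closed form, sum_j (u_i - u_j) v_j = u_i * sum(v) - <u, v>, so any lazy result can be compared against it. The sketch below uses NumPy instead of PyTorch and an arbitrary random seed; it only checks the dense arithmetic, not the pykeops code path.

```python
import numpy as np

# Arbitrary seed; u has shape (N, 1) and v has shape (N,), as in the post
rng = np.random.default_rng(0)
N = 5
u = rng.normal(size=(N, 1))
v = rng.normal(size=(N,))

# Dense computation: P[i, j] = u_i - u_j, followed by a matrix-vector product
P = u - u.T
Pv_dense = P @ v

# Closed form: sum_j (u_i - u_j) v_j = u_i * sum(v) - <u, v>
Pv_closed = u[:, 0] * v.sum() - u[:, 0] @ v

print(np.allclose(Pv_dense, Pv_closed))  # True
```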
joanglaunes commented 2 years ago

Hello Tom,

Thanks for your interest in our library. The first error is a clear bug in the transpose method .t() of the LazyTensor class, and I have now fixed it in the master branch. So, in Colab for example, your script now works if you install the master branch version with:

!pip install cmake==3.18
!pip install git+https://github.com/getkeops/keops.git@master

Thank you for noticing and mentioning this bug.

Now, the second problem you mention is not really a bug but rather an odd behavior due to the way we interpret dimensions in KeOps: when you convert a 2D tensor to a LazyTensor, the input shape is interpreted as (N,D) (sample size and dimension of each point/vector), so with the option axis=0 it gives a LazyTensor of shape (N,1,D) (the N dimension is put in first position), and with axis=1 it gives shape (1,N,D). In your case you input u.t(), which has shape (1,5), so it is interpreted as N=1 and D=5, and with axis=1 (or axis=0, since N=1) it gives a (1,1,5) LazyTensor. I know this is really confusing in your specific use case... We have already thought about changing this behaviour or adding another class similar to LazyTensor but with a clearer interface, but it is not done yet.
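The shape rule described above can be sketched as a small pure-Python function. `lazy_shape` is a hypothetical helper written for illustration, not part of pykeops; it only encodes the (N,D) → (N,1,D) / (1,N,D) placement stated in the comment, and shows why u.t() of shape (1,5) ends up as (1,1,5). It also shows why the workaround from the post avoids the issue: adding a trailing axis yields a 3D tensor that already has the (1,N,D) layout.

```python
import numpy as np


def lazy_shape(shape, axis=0):
    """Hypothetical helper (not a pykeops function): return the 3D shape that a
    2D input of shape (N, D) takes as a LazyTensor, per the rule described above."""
    if len(shape) != 2:
        raise ValueError("this sketch only covers 2D inputs")
    n, d = shape  # first axis = sample size N, second axis = dimension D
    if axis == 0:
        return (n, 1, d)  # N placed in first position
    if axis == 1:
        return (1, n, d)  # N placed in second position
    raise ValueError("axis must be 0 or 1")


print(lazy_shape((5, 1), axis=0))  # (5, 1, 1): u, read as N=5, D=1
print(lazy_shape((1, 5), axis=1))  # (1, 1, 5): u.t(), read as N=1, D=5

# The workaround from the post adds a trailing axis, so the input is already 3D
u = np.zeros((5, 1))
print(u.T[..., None].shape)  # (1, 5, 1): the (1, N, D) layout with D=1
```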

tvercaut commented 2 years ago

Many thanks @joanglaunes. That makes sense now.

jeanfeydy commented 2 years ago

Hi @joanglaunes,

Unfortunately, the .t() method is still broken in master. For instance, running the example from the docstring:

import torch
from pykeops.torch import LazyTensor

x, y = torch.randn(1000, 3), torch.randn(2000, 3)
x_i, y_j = LazyTensor( x[:,None,:] ), LazyTensor( y[None,:,:] )
K  = (- ((    x_i     -      y_j   )**2).sum(2) ).exp()  # Symbolic (1000,2000) Gaussian kernel matrix
K_ = (- ((x[:,None,:] - y[None,:,:])**2).sum(2) ).exp()  # Explicit (1000,2000) Gaussian kernel matrix
w  = torch.rand(1000, 2)
print( (K.t() @ w - K_.t() @ w).abs().mean() )

Fails with:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/data/jean/keops/pykeops/common/lazy_tensor.py", line 2204, in __matmul__
    Kv = Kv.sum(Kv.dim() - 2, **kwargs)  # Matrix-vector or Matrix-matrix product
  File "/data/jean/keops/pykeops/common/lazy_tensor.py", line 1777, in sum
    return self.reduction("Sum", axis=axis, **kwargs)
  File "/data/jean/keops/pykeops/common/lazy_tensor.py", line 733, in reduction
    res.fixvariables()  # Turn the "id(x)" numbers into consecutive labels
  File "/data/jean/keops/pykeops/common/lazy_tensor.py", line 291, in fixvariables
    str_cat_v = re.search(
                    r"Var\({},\d+,([012])\)".format(i), self.formula + self.formula2
                ).group(1)
AttributeError: 'NoneType' object has no attribute 'group'
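For reference, the quantity the docstring example compares against is an ordinary dense product: `K.t() @ w` sums the kernel against `w` over the first index i, i.e. out_j = sum_i K_ij w_i. The NumPy sketch below (smaller sizes and NumPy instead of PyTorch, both choices made here for illustration) checks that identity on the explicit kernel; it says nothing about how pykeops implements the transposed reduction internally.

```python
import numpy as np

# Small dense stand-ins for the docstring example (sizes reduced for speed)
rng = np.random.default_rng(0)
x, y = rng.normal(size=(100, 3)), rng.normal(size=(200, 3))
w = rng.random((100, 2))

# Explicit Gaussian kernel matrix K_ij = exp(-|x_i - y_j|^2), shape (100, 200)
K = np.exp(-((x[:, None, :] - y[None, :, :]) ** 2).sum(2))

# K.T @ w is a reduction over the i axis: out_j = sum_i K_ij w_i
out_matmul = K.T @ w
out_reduction = (K[:, :, None] * w[:, None, :]).sum(0)

print(out_matmul.shape)                        # (200, 2)
print(np.allclose(out_matmul, out_reduction))  # True
```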

Likewise, in some of my optimal transport scripts, I get the following stacktrace:

  ...
  File "/data/jean/geomloss/geomloss/examples/performances/benchmarks_ot_solvers.py", line 122, in plan_marginals
    B_j = b_j * (K_ij.t() @ a_i)  # Second marginal
  File "/data/jean/keops/pykeops/common/lazy_tensor.py", line 2260, in t
    y = self.tools.view(x, x.shape)
AttributeError: 'list' object has no attribute 'shape'

This error comes in a loop statement:

for x in self.variables:
    y = self.tools.view(x, x.shape)

Does this mean that the type of the elements in self.variables has changed from array-like to list? Or maybe the current implementation does not account for variables (e.g. constants) that the user passed as floats or lists instead of NumPy/PyTorch arrays?
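If the second hypothesis is right (some entries of `self.variables` are plain Python floats or lists), one defensive option would be to call `view` only on objects that expose a `.shape`. The sketch below is hypothetical: `view_variables` and `np_view` are illustrative stand-ins, not pykeops functions, and this is not necessarily how the bug was eventually fixed.

```python
import numpy as np


def view_variables(variables, view):
    """Hypothetical sketch, not the actual pykeops code: apply `view` only to
    array-like variables and pass Python scalars or lists through unchanged."""
    out = []
    for x in variables:
        if hasattr(x, "shape"):
            out.append(view(x, x.shape))
        else:
            out.append(x)  # e.g. a float or list constant given by the user
    return out


# Stand-in for self.tools.view: reshape to the given shape (assumption)
def np_view(x, shape):
    return np.reshape(x, shape)


vars_ = [np.ones((3, 1)), [0.5], 2.0]
res = view_variables(vars_, np_view)
print([type(v).__name__ for v in res])  # ['ndarray', 'list', 'float']
```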

I'm afraid of introducing new side effects if I fix it myself: do you think that you could have a look today? Otherwise, I'll try to be extra careful...

Best regards, Jean

joanglaunes commented 2 years ago

Hello @jeanfeydy, the bug in the first example (the docstring test) was actually caused by a small typo in the code, and I have now fixed it in master. Can you run your optimal transport scripts again and check whether the fix also solves the issue for them?

jeanfeydy commented 2 years ago

Hi @joanglaunes, Yes, as far as I can tell, things run smoothly now: thanks a lot! I am still fixing/documenting bugs as I encounter them. I will let you know when everything is OK for the release :-)

jeanfeydy commented 2 years ago

Hi @joanglaunes,

Thanks to your last fix on the ranges, nearly everything is working now :-) However, I have found one last error with the transpose, running the code snippet below:

import torch
from pykeops.torch import LazyTensor

N = 10
M = 20
D = 3

x_i = torch.randn(N, D)
y_j = torch.randn(M, D)
F_i = torch.randn(N)
G_j = torch.randn(M)
a_i = torch.randn(N)
b_j = torch.randn(M)
blur = 0.1

x_i = LazyTensor(x_i[:, None, :])
y_j = LazyTensor(y_j[None, :, :])
F_i = LazyTensor(F_i[:, None, None])
G_j = LazyTensor(G_j[None, :, None])

# Cost matrix:
C_ij = ((x_i - y_j) ** 2).sum(-1) / 2

# Scaled kernel matrix:
K_ij = ((F_i + G_j - C_ij) / blur ** 2).exp()

A_i = a_i * (K_ij @ b_j)  # First marginal
B_j = b_j * (K_ij.t() @ a_i)  # Second marginal

This fails with the same error message:

  ...
  File "/data/jean/keops/pykeops/common/lazy_tensor.py", line 2260, in t
    y = self.tools.view(x, x.shape)
AttributeError: 'list' object has no attribute 'shape'
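A dense NumPy reference for the snippet above can help check the lazy marginals once `.t()` works. One property worth asserting is that both marginals carry the same total mass, since sum_i A_i = a^T K b = sum_j B_j. Assumptions in this sketch: NumPy replaces PyTorch, the seed is arbitrary, and `blur` is set to 1.0 instead of 0.1 so that `exp` stays comfortably within float64 range for random potentials.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, D = 10, 20, 3
x_i, y_j = rng.normal(size=(N, D)), rng.normal(size=(M, D))
F_i, G_j = rng.normal(size=N), rng.normal(size=M)
a_i, b_j = rng.normal(size=N), rng.normal(size=M)
blur = 1.0  # assumption: larger than the 0.1 above, to avoid overflow in exp

# Cost matrix C_ij = |x_i - y_j|^2 / 2, shape (N, M)
C_ij = ((x_i[:, None, :] - y_j[None, :, :]) ** 2).sum(-1) / 2

# Scaled kernel matrix K_ij = exp((F_i + G_j - C_ij) / blur^2), shape (N, M)
K_ij = np.exp((F_i[:, None] + G_j[None, :] - C_ij) / blur ** 2)

A_i = a_i * (K_ij @ b_j)    # First marginal, shape (N,)
B_j = b_j * (K_ij.T @ a_i)  # Second marginal, shape (M,)

# Both marginals carry the same total mass a^T K b
print(np.isclose(A_i.sum(), B_j.sum()))  # True
```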

What do you think? Is there a way to make sure that we are not missing any other case, or is it too intricate?

Thanks again, Jean