getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Batched data for lazy tensor solve #274

Open · iscoyizj opened this issue 1 year ago

iscoyizj commented 1 year ago

Hi, I'm having trouble adapting this kernel interpolation tutorial to batched data:

http://kernel-operations.io/keops/_auto_tutorials/interpolation/plot_RBF_interpolation_torch.html#sphx-glr-auto-tutorials-interpolation-plot-rbf-interpolation-torch-py

Here is my code:

```python
import torch
from pykeops.torch import LazyTensor

def batch_gaussian_kernel(x, y, sigma=0.1):
    x_i = LazyTensor(x[:, :, None, :])  # (B, M, 1, D)
    y_j = LazyTensor(y[:, None, :, :])  # (B, 1, N, D)
    D_ij = ((x_i - y_j) ** 2).sum(-1)  # (B, M, N) symbolic matrices of squared distances
    return (-D_ij / (2 * sigma ** 2)).exp()  # (B, M, N) symbolic Gaussian kernel matrices

x = torch.rand(10, 5, 2)  # B=10 batches of M=5 points in 2D
b = torch.rand(10, 5, 2)
c = torch.rand(10, 5, 1)  # one right-hand side per batch
k_XX = batch_gaussian_kernel(x, x)
a = k_XX.solve(c)  # raises the ValueError below
```

And here is the error:

```
Traceback (most recent call last):
  File "", line 76, in <module>
    a = k_XX.solve(c)
  File "/home/anaconda3/envs/lib/python3.7/site-packages/pykeops/common/lazy_tensor.py", line 821, in solve
    x=other,axis=0
  File "/home/anaconda3/envs/lib/python3.7/site-packages/pykeops/torch/lazytensor/LazyTensor.py", line 72, in lt_constructor
    return LazyTensor(x=x, axis=axis, is_complex=is_complex)
  File "/home/anaconda3/envs/lib/python3.7/site-packages/pykeops/torch/lazytensor/LazyTensor.py", line 64, in __init__
    super().__init__(x=x, axis=axis)
  File "/home/anaconda3/envs/lib/python3.7/site-packages/pykeops/common/lazy_tensor.py", line 189, in __init__
    "'axis' parameter should not be given when 'x' is a 3D tensor."
ValueError: 'axis' parameter should not be given when 'x' is a 3D tensor.
```
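For small sizes, the result a batched solve should produce can be checked against a dense sketch in plain PyTorch that builds the same B independent Gaussian kernel systems and solves them with `torch.linalg.solve`. This is only a reference under stated assumptions, not how KeOps solves the system internally: the helper names `dense_batch_gaussian_kernel` and `dense_batch_solve` are hypothetical, and the small ridge term `alpha` is an assumption added to keep each system well-posed.

```python
import torch


def dense_batch_gaussian_kernel(x, y, sigma=0.1):
    # x: (B, M, D), y: (B, N, D) -> dense (B, M, N) Gaussian kernel matrices.
    # Same expression as the LazyTensor version, but every entry is stored in memory.
    D_ij = ((x[:, :, None, :] - y[:, None, :, :]) ** 2).sum(-1)  # (B, M, N)
    return (-D_ij / (2 * sigma ** 2)).exp()


def dense_batch_solve(x, c, sigma=0.1, alpha=1e-6):
    # Solves (K_xx + alpha * I) a = c independently for each batch element.
    # alpha is an assumed small ridge term, not a value taken from the tutorial.
    K_xx = dense_batch_gaussian_kernel(x, x, sigma)  # (B, M, M)
    eye = torch.eye(K_xx.shape[-1], dtype=K_xx.dtype, device=K_xx.device)
    return torch.linalg.solve(K_xx + alpha * eye, c)  # (B, M, 1)


torch.manual_seed(0)
x = torch.rand(10, 5, 2)  # B=10 batches of M=5 points in 2D
c = torch.rand(10, 5, 1)  # one right-hand side per batch
a = dense_batch_solve(x, c)
print(a.shape)  # torch.Size([10, 5, 1])
```

Memory for this dense version grows as B·M², so it is only practical for small point clouds; avoiding that materialization is the reason to use symbolic LazyTensors in the first place.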

Thanks a lot