getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Q: How to add a scalar to a diagonal of a LazyTensor? #124

Open awav opened 3 years ago

awav commented 3 years ago

Hi all,

I would like to add a scalar to the diagonal of a LazyTensor. As far as I understand, issue #24 is relevant. But is there another way of adding a value to a matrix diagonal? I would prefer to use a stable release rather than hacking around with raw KeOps code :)

A simple example would be:

n = ...
...
xi = LazyTensor(x[:, None, :])
yj = LazyTensor(y[None, :, :])
Dij = ((xi - yj) ** 2).sum(dim=2)
Kij = s * (-Dij / (l ** 2)).exp() + torch.eye(n)

jeanfeydy commented 3 years ago

Hi @awav ,

Thanks for your interest in the library! As of today, the simplest way of doing so is to introduce extra variables that encode the indices i and j, and to add an explicit boolean check "i == j" to the formula, as follows:

import torch
from pykeops.torch import LazyTensor

N, D = 10000, 10

# Create the arrays:
x = torch.randn(N, D).cuda()
y = torch.randn(N, D).cuda()
i = torch.arange(N).float().cuda()
j = torch.arange(N).float().cuda()
s, l = 100, .5

# Encode as symbolic LazyTensors:
xi = LazyTensor(x[:, None, :])
yj = LazyTensor(y[None, :, :])
i = LazyTensor(i[:, None, None])
j = LazyTensor(j[None, :, None])

# Gaussian kernel matrix:
Dij = ((xi - yj) ** 2).sum(dim=2)
Kij = s * (-Dij / (l ** 2)).exp()

# Add the diagonal (with an overkill comparison...):
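# step(t) equals 1 for t >= 0 and 0 otherwise, so the extra term
# is exactly 1 when i == j and 0 everywhere else: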
K_plus_Id_ij = Kij + (.5 - (i - j)**2).step()

# Check:
print(Kij.sum(1))
print(K_plus_Id_ij.sum(1))

Needless to say, this heavy solution is sub-optimal! (Note that for maximum performance, we should also divide x and y by l prior to the LazyTensor computation: this allows us to reduce the number of divisions from N**2 to 2 * N.)
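Concretely, a minimal sketch of that rescaling trick, reusing x, y, s and l from the snippet above:

# Divide the coordinates by l once, outside the symbolic N**2 loop:
x_l = x / l
y_l = y / l

xi = LazyTensor(x_l[:, None, :])
yj = LazyTensor(y_l[None, :, :])

# The 1 / l**2 factor is now baked into the squared distances:
Dij = ((xi - yj) ** 2).sum(dim=2)
Kij = s * (-Dij).exp()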

In most practical use cases, a better solution would be to use the SciPy abstraction for LinearOperators, as illustrated in this tutorial... But this is a NumPy-only solution, unfortunately. Going forward, we will probably have to "clone" this part of the SciPy package to provide a unified interface for both NumPy and PyTorch, plus hopefully JAX and TF at some point in the future. We're also planning to improve our support for batch computations, band-diagonal and triangular matrices as discussed in #121.
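For reference, here is a rough sketch of what such a wrapper can look like today with the NumPy backend, using SciPy's LinearOperator to represent "K + alpha * Id" in a matrix-free way (the shift alpha and the helper names are illustrative, not an existing KeOps API):

import numpy as np
from scipy.sparse.linalg import LinearOperator, cg
from pykeops.numpy import LazyTensor

N, D = 10000, 10
x = np.random.randn(N, D)
s, l = 100, 0.5
alpha = 1.0  # illustrative scalar shift on the diagonal

# Symbolic Gaussian kernel matrix (coordinates pre-divided by l):
x_i = LazyTensor(x[:, None, :] / l)
x_j = LazyTensor(x[None, :, :] / l)
D_ij = ((x_i - x_j) ** 2).sum(dim=2)
K_ij = s * (-D_ij).exp()

# Matrix-free operator for "K + alpha * Id": the diagonal shift is
# applied on the fly, and the N x N matrix is never materialized.
def matvec(v):
    v = v.reshape(N, 1)
    return (K_ij @ v).reshape(-1) + alpha * v.reshape(-1)

K_op = LinearOperator((N, N), matvec=matvec, dtype=x.dtype)

# e.g. solve (K + alpha * Id) u = b with conjugate gradients:
b = np.random.randn(N)
u, info = cg(K_op, b)

Since the operator only exposes a matvec, any SciPy iterative solver (cg, gmres, ...) can work with it directly.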

If you have any thoughts on this, please let us know! Now that the core engine of the library is mature, we're planning to add support for higher-level features. If you believe that we could do something to help your projects, we'll be happy to take it into account in our plans :-)

Best regards, Jean