getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Weighted sums reduction #354

Open DavidLapous opened 5 months ago

DavidLapous commented 5 months ago

Hi, and many thanks for your work. I have the following (reduced) problem, which I didn't manage to encode in pykeops:

On which I want to evaluate the convolutions against Gaussian measures (multiple ones, but let's say one for now).

Formally, I want something like $$\mathcal N(0,K) \star \mu_{\text{data}}(x_i) \quad \text{for } x_i \in x_{\text{data}},$$ where $$\mu_{\text{data}} = \sum_{x\in \mathrm{pts}} \delta_x\, \mathrm{weights}[x].$$ So my code is as follows:

My current workaround is to do a (Python) for loop over num_data (and the other axis that I removed for this example), roughly as sketched below, but this is starting to be slow. Is there a way to do this with pykeops? I didn't find anything in the API documentation.
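For concreteness, the loop is roughly along these lines (a simplified sketch with illustrative names and shapes, not my actual code):

import numpy as np

# Simplified sketch of the current workaround (illustrative shapes).
num_data, num_pts, num_x, D = 100, 200, 500, 2
pts     = np.random.uniform(size=(num_data, num_pts, D)).astype(np.float32)  # Dirac locations
weights = np.random.uniform(size=(num_data, num_pts)).astype(np.float32)     # Dirac weights
x       = np.random.uniform(size=(num_data, num_x, D)).astype(np.float32)    # evaluation points
K       = np.eye(D, dtype=np.float32)                                        # inverse covariance

convolutions = np.empty((num_data, num_x), dtype=np.float32)
for i in range(num_data):                           # <- this Python loop is the bottleneck
    z = pts[i][:, None, :] - x[i][None, :, :]       # (num_pts, num_x, D)
    quad = np.einsum("mnd,de,mne->mn", z, K, z)     # z^T K z
    convolutions[i] = (np.exp(-quad / 2) * weights[i][:, None]).sum(axis=0)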

jeanfeydy commented 5 months ago

Hi @DavidLapous ,

Thanks for your interest in the library! As far as I can tell, you are trying to compute a batched Gaussian convolution, which is very well supported.
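Roughly speaking, the pattern is to add a leading batch dimension to every LazyTensor; a minimal sketch (with placeholder shapes, to be adapted to your actual problem) would be:

import numpy as np
from pykeops.numpy import LazyTensor

B, M, N, D = 10, 200, 500, 2                       # batch, sources, targets, dimension
y = np.random.rand(B, M, D).astype(np.float32)     # source points
w = np.random.rand(B, M, 1).astype(np.float32)     # source weights
x = np.random.rand(B, N, D).astype(np.float32)     # target points

x_i = LazyTensor(x[:, :, None, :])                 # (B, N, 1, D): batched "i" variable
y_j = LazyTensor(y[:, None, :, :])                 # (B, 1, M, D): batched "j" variable
w_j = LazyTensor(w[:, None, :, :])                 # (B, 1, M, 1): batched "j" variable

D_ij = ((x_i - y_j) ** 2).sum(-1)                  # symbolic (B, N, M) squared distances
K_ij = (-D_ij / 2).exp()                           # Gaussian kernel
conv = (K_ij * w_j).sum(2)                         # reduce over the source axis -> (B, N, 1)

The (B, N, M) kernel matrix is never stored in memory: the reduction over the source axis is performed on the fly.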

Could you maybe write a full "minimal working example" (with realistic values for num_data, num_pts, num_x and D, since their orders of magnitude matter to select an optimal implementation) of what you are trying to compute?

Then, I will be able to give you specific performance tips.

Best regards, Jean

DavidLapous commented 4 months ago

Hi @jeanfeydy, thanks for your response :)

I think I found a workaround; it seems to work, but I'm not super sure, as it relies on a "non-valid" operation.

Basically, a simplified version of what I'm doing is something like this:

import numpy as np
num_data = 100 # range : 100 - 5 000
mean_num_pts = 200 # range : 100 - 5 000
D = 2 # less than 5 in general
num_x = 500 # range : num_pts - 1e9
dtype=np.float32

# each data point doesn't have the same number of Diracs in its measure:
# we pad the ragged tensor, and the multiplication by the weights at the end
# kills the dummy points. It's fine as the variance is "small" in practice.
num_pts_variance = 100
num_pts_per_data = np.random.normal(mean_num_pts,num_pts_variance, size=(num_data,)).astype(int) 
max_num_pts = num_pts_per_data.max()
pts = np.random.uniform(size=(num_data,max_num_pts, D)).astype(np.float32) ## unragged pts
weights = np.random.choice([-2,-1,1,2], size=(num_data,max_num_pts), replace=True).astype(np.float32)
for i,data_num_pts in enumerate(num_pts_per_data):
    weights[i,data_num_pts:]=0 # kills the dummy variables
# The signed point measure is \sum \delta_x *w for x,w in zip(pts,weights)

# the points on which to evaluate the convolution
x = np.random.uniform(size = (num_data, num_x,D)).astype(np.float32)

from pykeops.numpy import LazyTensor
## (num_data, max_num_pts, num_x, D)
lazy_diracs = LazyTensor(pts[:,:,None,:])
lazy_x      = LazyTensor(x[:,None,:,:])
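## pts[:, :, None, :] has shape (num_data, max_num_pts, 1, D): a batched "i" variable.
## x[:, None, :, :]   has shape (num_data, 1, num_x, D):       a batched "j" variable.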

K = np.random.uniform(size=(D,D))
K = K.T@K + np.diag([.01]*D) # some covariance matrix inverse
K = K.astype(np.float32)

## batch convolutions
z = lazy_diracs - lazy_x
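## weightedsqnorm(s) should evaluate the quadratic form <z, S z>, with the
## symmetric D x D matrix S passed as a flattened vector of length D*D.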
exponent = -(z.weightedsqnorm(K.flatten()))/2
exponent = exponent.exp()
assert exponent.shape == (num_data,max_num_pts, num_x)
assert weights.shape == (num_data,max_num_pts,)

convolutions = (exponent*weights[:,:,None,None]).sum(1)
print(convolutions.shape) # (num_data, num_x, 1)

The only part that changed from the previous version of my code is the final reduction line

convolutions = (exponent*weights[:,:,None,None]).sum(1)

where I added an extra None to the weights (pure luck on this, haha). This should not work, since in this example the shapes of exponent, namely (num_data, max_num_pts, num_x), and of weights[:,:,None,None], namely (num_data, max_num_pts, 1, 1), are not broadcastable onto each other. I'm not really sure what's happening under the hood, but the "lazy constraint" that makes weights[:,:,None] non-Lazy-ifiable, combined with the fact that adding one more None makes it Lazy and the computation valid, is very fortunate.
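For reference, my best guess is that the implicit broadcast is equivalent to wrapping the weights explicitly as a batched "i" variable of feature dimension 1 (I have not checked the internals, so take this as a guess):

## Guess: explicit version of the weight broadcast above.
## weights[:, :, None, None] has shape (num_data, max_num_pts, 1, 1),
## i.e. a batched "i" variable with a feature dimension of 1.
lazy_weights = LazyTensor(weights[:, :, None, None])
convolutions_explicit = (exponent * lazy_weights).sum(1)  # (num_data, num_x, 1)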

I only tested the computation on small examples (along the lines of the check below), but this works. Can you confirm that I'm not doing anything wrong? I'm also curious about the wizardry that makes this work.
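For reference, here is the kind of small-scale check I have in mind (dense NumPy on a handful of measures only, so the dense tensors stay small; names as in the snippet above):

## Sanity check against a dense NumPy computation, on the first few measures.
sl = slice(0, 5)
z_dense = pts[sl, :, None, :] - x[sl, None, :, :]            # (5, max_num_pts, num_x, D)
quad = np.einsum("bmnd,de,bmne->bmn", z_dense, K, z_dense)   # z^T K z
dense = (np.exp(-quad / 2) * weights[sl, :, None]).sum(axis=1)
assert np.allclose(convolutions[sl, :, 0], dense, atol=1e-3)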

Also, in my setup there is another (usually small) axis for both the covariance matrices and the data, but I think this is fine: the for-loop bottleneck is num_data, I think. I'm also using the torch backend, which, fortunately, has the same behaviour.

Kind regards, -- David