getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Mask a LazyTensor with a sparse matrix #267

Closed skyshine102 closed 1 year ago

skyshine102 commented 1 year ago

I have a tensor X of shape [b, N, d] and a mask A of shape [N, N]. I would like to compute the squared Euclidean distances between all pairs of points and mask the result by A lazily. Following the tutorial, I can get the squared distance matrix K with

x_i = LazyTensor(X.view(b, 1, N, d))
x_j = LazyTensor(X.view(b, N, 1, d))
K_ij = ((x_i - x_j)**2).sum(-1) # LazyTensor shape (b, N, N)

# Then I want to mask K with A by elementwise multiplication
# Something like W_ij = A_ij * K_ij  
A_ij = LazyTensor(A.expand(b, N, N))  # --> fails: cannot create a LazyTensor from A with shape [b, N, N]; only [..., M, 1, D], [..., 1, N, D] or [..., 1, 1, D] are allowed
W_ij = A_ij * K_ij 
out = W_ij.sum(-1)

Is there any way to achieve the above operations? Looking for help.

jeanfeydy commented 1 year ago

Hi @skyshine102,

Thanks for your interest in the library! Does your mask A have some kind of structure?

If it has a "mathematical" structure (e.g. if it comes from a threshold on some kind of distance function that depends on x_i and x_j), you should be able to encode A_ij as a LazyTensor, just like K_ij. Alternatively, if your mask has a block-wise, triangular or band-like structure, you may be able to implement it efficiently using our syntax for block-sparsity masks.

On the other hand, if your mask is a "random" (N,N) binary matrix, I'm afraid that KeOps won't be able to help: as discussed e.g. in #203, KeOps speed-ups come from the fact that we avoid the explicit storage and transfer of large [N,N] variables. If you cannot "compress" the information that is encoded by your mask A by using its structure, then relying on a brute-force PyTorch implementation may be the best possible option. What do you think?

Best regards, Jean
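For reference, the brute-force PyTorch fallback mentioned above could look like the minimal sketch below. The function name `masked_sq_dist_rowsum` and the toy data are made up for illustration and do not come from the thread. The sketch assumes that A broadcasts over the batch dimension and that the intermediate [b, N, N, d] tensor fits in memory, which is exactly the cost that KeOps normally avoids.

```python
import torch


def masked_sq_dist_rowsum(X: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Dense baseline: out[b, i] = sum_j A[i, j] * ||X[b, i] - X[b, j]||^2."""
    # (b, N, 1, d) - (b, 1, N, d) -> (b, N, N, d), then reduce over d.
    K = ((X[:, :, None, :] - X[:, None, :, :]) ** 2).sum(-1)  # (b, N, N)
    # A has shape (N, N) and broadcasts over the batch dimension.
    return (A.to(K.dtype) * K).sum(-1)  # (b, N)


# Toy example (hypothetical data): b = 1, N = 3, d = 2.
X = torch.tensor([[[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]])
A = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]])
out = masked_sq_dist_rowsum(X, A)
print(out)
```

Memory use grows as b * N * N * d here, so this is only practical for moderate N.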

skyshine102 commented 1 year ago

Hi Jean,

Thank you for your prompt and detailed reply! My mask A comes from an independent source (a graph relation). It may have a cluster-like structure but I need to check further. But even if I rearrange my A through node clustering as in Figure 1 of this, I cannot guarantee that my A has a block-like structure.

Best regards, Jeremy

jeanfeydy commented 1 year ago

Hi @skyshine102,

I see: in that case, it is likely that KeOps is not the right tool for the job. (The only solution would be to compute an embedding of your graph in a vector space and encode your mask as a kernel matrix - but this is likely too cumbersome if the "symbolic formula" that you are trying to evaluate is just a squared Euclidean distance.) You may have better luck representing your matrix W_ij as a sparse matrix and manipulating it using e.g. the PyG library. I use the torch_scatter extension quite extensively whenever I deal with "real" graphs: since all those extensions run smoothly via the PyTorch ecosystem, it is now very easy to pick the "best" numerical routine for any given task.

Best regards and good luck with your project! Jean
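The sparse route suggested above can be sketched in plain PyTorch. This swaps in `Tensor.index_add_` for the scatter-sum, rather than using torch_scatter or PyG, so that the example has no extra dependency. The function name `edge_masked_sq_dist_rowsum` and the toy data are hypothetical. The idea is to evaluate the squared distance only on the nonzero entries (edges) of A, so the cost scales with the number of edges E instead of N * N.

```python
import torch


def edge_masked_sq_dist_rowsum(X: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Sparse variant: out[b, i] = sum over edges (i, j) of A[i, j] * ||X[b, i] - X[b, j]||^2."""
    b, N, _ = X.shape
    # Edge list of the graph: indices of the nonzero entries of A, each of shape (E,).
    rows, cols = A.nonzero(as_tuple=True)
    vals = A[rows, cols].to(X.dtype)  # (E,) edge weights (all ones for a binary mask)
    # Squared distances evaluated on the edges only: (b, E).
    d2 = ((X[:, rows, :] - X[:, cols, :]) ** 2).sum(-1)
    # Scatter-sum each edge value into the row index of its source node.
    out = torch.zeros(b, N, dtype=X.dtype, device=X.device)
    out.index_add_(1, rows, vals * d2)
    return out


# Toy example (hypothetical data): b = 1, N = 3, d = 2.
X = torch.tensor([[[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]])
A = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]])
out = edge_masked_sq_dist_rowsum(X, A)
print(out)
```

In practice the graph would typically be stored directly as an edge list, so the dense A and the `nonzero` call would not be needed at all.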