getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Elaborate message passing #247

Open Jostarndt opened 2 years ago

Jostarndt commented 2 years ago

Hey there, first of all: thank you very much for your amazing work. I am super impressed by and happy about KeOps. On your homepage, I stumbled upon message passing layers listed as a possible application, and I am not sure I understood everything correctly. Unfortunately, there is no tutorial on this topic, so I thought asking here might be an option. Sorry if this should have been clear from reading your NeurIPS paper, but I am having some difficulties. A longer example including code might help me, and could also further enlarge your audience.

One of my main questions concerns the concrete meaning of “semi-symbolic” in the context of message passing. Following your message passing example, suppose I update an aggregation a_i with a sum of filter values over the neighbors of node i, i.e. a_i = Σ_j A_ij F(x_i, x_j), by multiplying a sparse adjacency matrix A with a filter F. This might be computationally expensive. Does KeOps also compute the many entries of F that are unnecessary due to the sparsity of A? Or does “semi-symbolic” mean that the sparsity prevents these unnecessary computations? Otherwise, how would you suggest computing this sum?

My next question is related: what if I do not want a summation but a nonlinear aggregation such as max? Is there a suggested way to implement this with KeOps that I haven't found yet?

Also, is the code for your comparison of KeOps and PyTorch Geometric on this task available anywhere?

I would really appreciate any help in understanding KeOps better, as it seems to deliver exactly what I am looking for. Thank you very much for the incredible documentation and programming work so far!

jeanfeydy commented 2 years ago

Hi @Jostarndt,

Thanks for your kind words :-)

To answer your questions: yes, out of the box, KeOps computes all pairwise filter values F(x[i], y[j]). This default behavior has quadratic complexity O(#i * #j), but an extremely good “constant” as long as F is not too complex: as a rule of thumb, fewer than 50-100 arithmetic operations.
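For illustration, here is a minimal NumPy sketch (plain NumPy, not KeOps code) of the aggregation a_i = Σ_j A_ij F(x_i, x_j) for a binary adjacency matrix A. The Gaussian filter `filter_fn` and the helpers `dense_aggregate` and `sparse_aggregate` are hypothetical names chosen for this example. The dense version evaluates F on every pair, which is what a brute-force reduction amounts to; the edge-list version only evaluates F on the nonzero entries of A, in the spirit of scatter-based graph libraries.

```python
import numpy as np


def filter_fn(x_i, x_j):
    # Hypothetical cheap filter F(x_i, x_j) = exp(-|x_i - x_j|^2).
    # Broadcasts over all leading axes; reduces over the feature axis.
    return np.exp(-np.sum((x_i - x_j) ** 2, axis=-1))


def dense_aggregate(x, adjacency):
    # Brute-force reference: evaluate F on ALL N * N pairs, then mask with A.
    # The cost is O(N^2) filter evaluations, whatever the sparsity of A.
    f_ij = filter_fn(x[:, None, :], x[None, :, :])  # (N, N) filter values
    return (adjacency * f_ij).sum(axis=1), f_ij.size


def sparse_aggregate(x, src, dst):
    # Edge-list alternative (assumes a binary A): one F evaluation per edge
    # j -> i with A[i, j] != 0, then scatter-add each message into node i.
    messages = filter_fn(x[dst], x[src])
    a = np.zeros(len(x))
    np.add.at(a, dst, messages)
    return a, messages.size


rng = np.random.default_rng(0)
x = rng.normal(size=(200, 3))                               # node features
adjacency = (rng.random((200, 200)) < 0.02).astype(float)   # ~2% of pairs are edges
dst, src = np.nonzero(adjacency)                            # row = target i, column = source j
a_dense, n_dense = dense_aggregate(x, adjacency)
a_sparse, n_sparse = sparse_aggregate(x, src, dst)
print(np.allclose(a_dense, a_sparse))  # True: same aggregation
print(n_dense, n_sparse)               # 40000 filter evaluations vs. one per edge
```

Both versions give the same a_i; only the number of filter evaluations differs. The brute-force approach can still be competitive when F is cheap, since KeOps evaluates F on the fly and never stores an N × N array, unlike the dense NumPy reference above.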

In practice, this means that the default “bruteforce” mode of KeOps has two main use cases:

Going further, KeOps also supports block-wise sparsity masks: it is actually possible to “skip” unnecessary computations using a “coarse” description of your connectivity structure. This is especially useful if you manipulate “graphs” with more than 10k nodes.
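The block-wise skipping idea can be sketched in NumPy as follows (an illustration of the principle, not KeOps' implementation). A coarse mask keeps a block of pairs (I, J) only if some edge links block I to block J, and F is never evaluated on skipped blocks. The helper names and the fixed `block_size` are assumptions. The sketch also assumes the nodes are already ordered so that neighbors fall in the same or nearby blocks (e.g. after a spatial clustering); otherwise almost no block can be skipped. For simplicity, the coarse mask is read off a dense adjacency matrix, which a large-scale implementation would avoid. In PyKeOps, the corresponding mechanism appears to be the `ranges` option of reductions together with helpers in `pykeops.torch.cluster`; the block-sparse reduction tutorial in the documentation is the authoritative reference for that API.

```python
import numpy as np


def filter_fn(x_i, x_j):
    # Hypothetical cheap filter F(x_i, x_j) = exp(-|x_i - x_j|^2).
    return np.exp(-np.sum((x_i - x_j) ** 2, axis=-1))


def block_sparse_aggregate(x, adjacency, block_size):
    # a_i = sum_j A_ij F(x_i, x_j), evaluated block by block.
    n = len(x)
    starts = range(0, n, block_size)
    a = np.zeros(n)
    evaluations = 0
    for i0 in starts:
        i1 = min(i0 + block_size, n)
        for j0 in starts:
            j1 = min(j0 + block_size, n)
            mask = adjacency[i0:i1, j0:j1]
            if not mask.any():
                continue  # coarse mask says "no edge here": skip the whole block
            f_ij = filter_fn(x[i0:i1, None, :], x[None, j0:j1, :])
            a[i0:i1] += (mask * f_ij).sum(axis=1)
            evaluations += f_ij.size
    return a, evaluations


# Banded toy graph: nodes sorted along a line, edges only between close indices.
rng = np.random.default_rng(1)
n = 1000
x = np.sort(rng.normal(size=(n, 1)), axis=0)
idx = np.arange(n)
adjacency = (np.abs(idx[:, None] - idx[None, :]) <= 2).astype(float)

a_block, n_block = block_sparse_aggregate(x, adjacency, block_size=100)
a_dense = (adjacency * filter_fn(x[:, None, :], x[None, :, :])).sum(axis=1)
print(np.allclose(a_block, a_dense))  # True: skipped blocks contribute nothing
print(n_block, n * n)                 # 280000 1000000
```

Here only the 10 diagonal blocks and the 18 blocks next to them contain edges, so 28 of the 100 blocks are computed.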

With respect to your question about max(): it is fully supported, although I think it is not yet differentiable. A work-around is discussed here. LogSumExp reductions (fully differentiable) are usually a better option anyway. Is there anything else you would like to know?
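To compare the two aggregations, here is a small NumPy sketch (not KeOps code) of a max over neighbors and of its LogSumExp relaxation with a temperature T. The function names and the `temperature` parameter are illustrative assumptions. The relaxation satisfies max_j v_ij ≤ T log Σ_j exp(v_ij / T) ≤ max_j v_ij + T log(deg_i), so it tends to the max as T → 0, and its gradient with respect to v_ij is a softmax over the neighbors of i rather than a one-hot selection. On KeOps LazyTensors, the corresponding reductions are `.max(dim=1)` and `.logsumexp(dim=1)`.

```python
import numpy as np


def masked_max(values, adjacency):
    # Max aggregation over neighbors: non-edges become -inf so they never win.
    # Assumes every node has at least one neighbor (else the row max is -inf).
    return np.where(adjacency > 0, values, -np.inf).max(axis=1)


def masked_logsumexp(values, adjacency, temperature=1.0):
    # Smooth surrogate: T * log(sum_j exp(v_ij / T)) over the neighbors j of i,
    # computed stably by factoring out the row max before exponentiating.
    z = np.where(adjacency > 0, values / temperature, -np.inf)
    m = z.max(axis=1, keepdims=True)
    return temperature * (m[:, 0] + np.log(np.exp(z - m).sum(axis=1)))


rng = np.random.default_rng(2)
n = 30
values = rng.normal(size=(n, n))                  # message v_ij sent from j to i
adjacency = (rng.random((n, n)) < 0.2).astype(float)
np.fill_diagonal(adjacency, 1.0)                  # self-loops: no empty neighborhood
hard = masked_max(values, adjacency)
for t in (1.0, 0.1, 0.01):
    soft = masked_logsumexp(values, adjacency, temperature=t)
    print(t, np.abs(soft - hard).max())           # gap shrinks as T decreases
```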

With respect to code examples:

Does this answer your questions? Please let me know if you would like to know anything else :-)

Best regards, Jean

jeanfeydy commented 2 years ago

P.S.: Oh, and one last use case: you can simply use KeOps to speed up the construction of the K-NN graph :-)
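The K-NN construction is itself a brute-force reduction: compute all pairwise squared distances, then keep the K smallest per row. Below is a minimal NumPy sketch of that logic; the function names `knn_graph` and `knn_edges` are hypothetical. On KeOps LazyTensors, the analogous reduction would be `D_ij.argKmin(K, dim=1)` on a symbolic distance matrix, which avoids materializing the N × N matrix that this NumPy version stores.

```python
import numpy as np


def knn_graph(x, k):
    # Brute-force K-NN: all pairwise squared distances, then the k smallest
    # per row. Each point is its own nearest neighbor (distance 0), so the
    # first column is the point itself. Assumes 1 <= k <= len(x).
    d2 = ((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)   # (N, N)
    idx = np.argpartition(d2, kth=k - 1, axis=1)[:, :k]         # k smallest, unordered
    order = np.argsort(np.take_along_axis(d2, idx, axis=1), axis=1)
    return np.take_along_axis(idx, order, axis=1)               # sorted by distance


def knn_edges(neighbors):
    # Turn the (N, k) neighbor table into an edge list src -> dst,
    # dropping the self-loop stored in the first column.
    n, k = neighbors.shape
    dst = np.repeat(np.arange(n), k - 1)
    src = neighbors[:, 1:].ravel()
    return src, dst


rng = np.random.default_rng(3)
x = rng.normal(size=(100, 2))
neighbors = knn_graph(x, k=6)      # self + 5 nearest neighbors per node
src, dst = knn_edges(neighbors)
print(neighbors.shape, src.shape)  # (100, 6) (500,)
```

The resulting edge list can then feed any of the aggregation schemes discussed above.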