getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Backpropagation through `.min()` #269

Open lostmsu opened 1 year ago

lostmsu commented 1 year ago

Just wanted to create a tracking issue for this warning: https://github.com/getkeops/keops/blob/7d68fd9da5887c217937a35210c740b2b9f9f79d/pykeops/pykeops/tutorials/a_LazyTensors/plot_lazytensors_a.py#L212

The documentation also lacks an explanation of how to use `argmin` for backpropagation, as the warning suggests. Simply using the indices returned by `argmin` would require recomputing the values in PyTorch, which negates the benefit of using KeOps in the first place.

jeanfeydy commented 1 year ago

Hi @lostmsu,

Thanks for your interest in this library! We (@bcharlier, @joanglaunes and myself) are currently busy with the start of the academic year, but smoothing out such edges is certainly on our todo-list for the year. (The plan is to introduce a new SymbolicArray wrapper that will be compatible with the new standard array API, while keeping LazyTensors for backward compatibility.)

In any case, please note that for `.min()` reductions, recomputing the values from the KeOps indices in PyTorch does not negate the benefit of using KeOps. In the tutorial that you linked here, the fast KeOps routines handle the costly O(M*N) search for the nearest-neighbor indices. The PyTorch recomputation of the distances (in a fully differentiable way) is only O(M) and should have a negligible impact on run times. As a consequence, adding explicit support for backpropagation through `.min()` is important for usability, but not absolutely necessary.

What do you think? Best regards, Jean
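
A minimal sketch of the argmin-then-recompute pattern described above, under stated assumptions: the point clouds `x` and `y`, their sizes, and the squared Euclidean distance are illustrative choices, and a dense PyTorch search stands in for the KeOps reduction so that the snippet runs without pykeops installed. The commented lines show roughly how the index search might look with a `LazyTensor`; treat them as an untested outline rather than the tutorial's exact code.

```python
import torch

torch.manual_seed(0)
M, N, D = 200, 300, 3  # illustrative sizes (assumption)
x = torch.randn(M, D, requires_grad=True)  # query points
y = torch.randn(N, D, requires_grad=True)  # target points

# Step 1: non-differentiable O(M*N) search for nearest-neighbor indices.
# With KeOps this would be a symbolic reduction, roughly:
#   from pykeops.torch import LazyTensor
#   x_i = LazyTensor(x[:, None, :])
#   y_j = LazyTensor(y[None, :, :])
#   idx = ((x_i - y_j) ** 2).sum(-1).argmin(dim=1).view(-1)
# A dense PyTorch computation stands in for it here.
with torch.no_grad():
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # (M, N)
    idx = sq_dists.argmin(dim=1)  # (M,)

# Step 2: differentiable O(M) recomputation of the minimum values.
min_sq_dist = ((x - y[idx]) ** 2).sum(-1)  # (M,)

# Gradients flow to x and to the selected rows of y.
min_sq_dist.sum().backward()
```

Only step 2 is tracked by autograd, so the backward pass touches M pairs instead of M*N.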

lostmsu commented 1 year ago

I was not looking at that specific example. A minimal repro for my problem with `min` is a matmul followed by a `min` reduction. The hope was that KeOps would let me do that without materializing the full matmul output.

But in order to use the indices directly without reinventing what KeOps did to compute them, I'd need to essentially do `gather(matmul(...), indices)`.
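
For the matmul-then-min case, one possible way to read `gather(matmul(...), indices)` is as a row-wise dot product: the selected entry of row `i` is the dot product of `A[i]` with column `idx[i]` of `B`, so the full M-by-N product is not needed for the differentiable step. The sketch below assumes hypothetical operands `A` of shape (M, K) and `B` of shape (K, N); the dense `argmin` is only a stand-in for a KeOps index search, which might look roughly like the commented lines (untested outline).

```python
import torch

torch.manual_seed(0)
M, N, K = 64, 128, 16  # illustrative sizes (assumption)
A = torch.randn(M, K, requires_grad=True)
B = torch.randn(K, N, requires_grad=True)

# Step 1: non-differentiable index search over the (virtual) product A @ B.
# With KeOps this might be written roughly as:
#   from pykeops.torch import LazyTensor
#   A_i = LazyTensor(A[:, None, :])                   # (M, 1, K)
#   B_j = LazyTensor(B.T.contiguous()[None, :, :])    # (1, N, K)
#   idx = (A_i * B_j).sum(-1).argmin(dim=1).view(-1)
# The dense product below is a stand-in so the sketch runs without pykeops.
with torch.no_grad():
    idx = (A @ B).argmin(dim=1)  # (M,)

# Step 2: differentiable recomputation of the selected entries only.
# values[i] = sum_k A[i, k] * B[k, idx[i]], using O(M*K) time and memory.
values = (A * B[:, idx].T).sum(dim=1)  # (M,)

values.sum().backward()  # gradients reach A and the selected columns of B
```

The intermediate `B[:, idx]` has shape (K, M), so step 2 never allocates an M-by-N array.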