getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Does KeOps backpropagation exploit the "diagonal" structure of reduction operations? #301

Open AdrienWohrer opened 1 year ago

AdrienWohrer commented 1 year ago

Hi, I am using KeOps/PyTorch to implement an ordinary differential equation on a bunch of variables.

The ODE is implemented with some numerical scheme, say Euler for simplicity, with a structure of the form:

    Y[:, t+1] = Y[:, t] + epsilon * G(Y[:, t])
    X[n, t+1] = X[n, t] + epsilon * sum_m F(X[n, t], Y[m, t])    for all n in 1...N

where the sum over m is implemented with a "sum" reduction over some KeOps symbolic formula.
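For concreteness, here is a minimal sketch of the kind of loop I have in mind; the sizes, the Gaussian choice of F, the linear choice of G and the names (`coupling`, etc.) are placeholders, not my actual formulas:

```python
import torch
from pykeops.torch import LazyTensor

# Placeholder sizes and time step, roughly matching my setting (N >> M).
N, M, T, eps = 1000, 100, 50, 1e-2

X0 = torch.randn(N)                        # initial X: one scalar per index n
Y0 = torch.randn(M, requires_grad=True)    # initial Y: gradient w.r.t. this is what I want

def G(y):
    # Placeholder drift for the "true" M-dimensional ODE on Y.
    return -y

def coupling(x_t, y_t):
    # sum_m F(x_n, y_m), for all n, as a KeOps "sum" reduction over m.
    x_i = LazyTensor(x_t.view(N, 1, 1))    # i-indexed (non-reduced) variable
    y_j = LazyTensor(y_t.view(1, M, 1))    # j-indexed (reduced) variable
    F_ij = (-(x_i - y_j) ** 2).exp()       # placeholder for F(x_n, y_m)
    return F_ij.sum(dim=1).reshape(N)      # reduction over m -> shape (N,)

# Explicit Euler scheme, keeping the whole X trajectory.
Xs, Y = [X0], Y0
for t in range(T):
    Xs.append(Xs[-1] + eps * coupling(Xs[-1], Y))
    Y = Y + eps * G(Y)
```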

The important point here is that the ODE governing Y is a "true" M-dimensional ODE, whose time derivative is given by some function G : R^M -> R^M, whereas the ODE governing X is factorized across the variables (x_n): each variable x_n is governed by a 1-dimensional ODE involving only itself and the Y variables. This is precisely the natural structure of KeOps reduction formulas, which are evaluated in parallel for each index of the non-reduced variable (here X).

I then consider some scalar quantity l=L(X), computed from the whole trajectory of variables X (that is, all scalar variables X[n,t] from t=0 to t=T). Through the above ODE, the whole trajectory X is itself a function of the initial condition (X0,Y0) only, and I want to estimate the differential of the final number l=L(X) w.r.t. the initial values Y0.
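In code, with the placeholder loop above, this would simply read (the quadratic choice of L below is again just a stand-in for my real criterion):

```python
# Placeholder scalar loss computed from the whole X trajectory.
l = sum((x_t ** 2).mean() for x_t in Xs)

l.backward()          # back-propagate through the T Euler steps
grad_Y0 = Y0.grad     # differential of l = L(X) w.r.t. the initial values Y0
```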

In theory, this gradient can be computed by back-propagation. A naive description of the back-propagation in this case would be: there is a "global" function H : R^(N+M) -> R^(N+M) such that (X[:,t+1], Y[:,t+1]) = H(X[:,t], Y[:,t]), and each back-propagation step from time "t+1" back to time "t" amounts to multiplying the gradient by the differential matrix dH at that point, i.e., an (N+M)x(N+M) matrix.

However, in my particular case, this general approach is computational overkill, because the variables (x_n) have no direct interaction with each other; in other words, the differential matrix dH is simply diagonal as far as its X components are concerned. The naive back-propagation would therefore imply multiplying by a huge (N+M)x(N+M) matrix that is mostly full of zeros.
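Writing it out for the explicit Euler scheme above, just to fix ideas, the differential of one step (X[:,t+1], Y[:,t+1]) = H(X[:,t], Y[:,t]) has the block structure

$$
dH =
\begin{pmatrix}
I_N + \epsilon \,\mathrm{diag}\Big(\sum_m \partial_x F(x_n, y_m)\Big)_{n} & \epsilon \,\Big(\partial_y F(x_n, y_m)\Big)_{n,m} \\
0 & I_M + \epsilon \, dG
\end{pmatrix}
$$

so applying its transpose to a gradient only costs on the order of N*M + M^2 operations (dominated by the dense N-by-M coupling block), instead of (N+M)^2 for a generic dense matrix.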

So, finally, to my question: is KeOps' autograd (and its interaction with PyTorch's autograd) optimized enough to exploit this particular "diagonal" structure associated with the reductions? Can I rely on a simple PyTorch call, l.backward(), to compute the gradient of l=L(X) w.r.t. Y0? Or will that be largely sub-optimal in terms of computation time? To stress the issue again, my X variables are much more numerous (N ~= 1000) than my Y variables (M ~= 100), so exploiting the independence structure of the X variables is crucial for me.