getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
MIT License

Does keops backpropagation exploit the "diagonal" structure of reduction operations ? #301

Open AdrienWohrer opened 1 year ago

AdrienWohrer commented 1 year ago

Hi, I am using keops/pytorch to implement an ordinary differential equation on a bunch of variables.

The ODE is implemented with some numerical scheme, say Euler to simplify, with a structure of the form:

    Y[:, t+1] = Y[:, t] + epsilon * G(Y[:, t])
    X[n, t+1] = X[n, t] + epsilon * sum_m F(X[n, t], Y[m, t])    for all n in 1...N

the sum over m being implemented with a "sum" reduction over some keops symbolic formula.
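A minimal sketch of this scheme in plain Python may help fix the notation. The functions `F` and `G` below are hypothetical placeholders (the post does not specify them), each x_n is taken to be a scalar, and the helper names `euler_step` and `integrate` are invented for illustration. The inner `sum(...)` plays the role that the keops "sum" reduction over m would play in the real code.

```python
import math

# Hypothetical right-hand sides, chosen only for illustration.
def G(Y):
    """Coupled M-dimensional drift for Y: each output depends on all of Y."""
    mean = sum(Y) / len(Y)
    return [mean - y for y in Y]

def F(x, y):
    """Scalar interaction between one x_n and one y_m."""
    return math.exp(-(x - y) ** 2)

def euler_step(X, Y, eps):
    """One explicit Euler step of the scheme, on plain Python lists."""
    Y_next = [y + eps * g for y, g in zip(Y, G(Y))]
    # "sum" reduction over m, done independently for each n.
    X_next = [x + eps * sum(F(x, y) for y in Y) for x in X]
    return X_next, Y_next

def integrate(X0, Y0, eps, n_steps):
    """Return the X trajectory [X[:, 0], ..., X[:, n_steps]] and the final Y."""
    X, Y = list(X0), list(Y0)
    traj = [X]
    for _ in range(n_steps):
        X, Y = euler_step(X, Y, eps)
        traj.append(X)
    return traj, Y
```

Note that `X_next[n]` reads only `X[n]` and the whole of `Y`, which is the factorized structure the rest of the question relies on.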

The important point here is that the ODE governing Y is a "true" M-dimensional ODE, whose time derivative is given by some function G: R^M -> R^M, whereas the ODE governing X is factorized across variables (x_n), i.e., each variable x_n follows a 1-dimensional ODE involving only itself and the Y variables. Indeed, this is the natural structure of keops reduction formulas, which are evaluated in parallel for each index of the non-reduced variable (here X).

I then consider some scalar quantity l=L(X), computed from the whole trajectory of variables X (that is, all scalar variables X[n,t] from t=0 to t=T). Through the above ODE, the whole trajectory X is itself a function of the initial condition (X0,Y0) only, and I want to compute the gradient of the final scalar l=L(X) w.r.t. the initial values Y0.

In theory, this gradient can be computed by back-propagation. A naive description of the back-propagation in this case would be that there is a "global" function H : R^(N+M) -> R^(N+M) such that

    (X[:, t+1], Y[:, t+1]) = H(X[:, t], Y[:, t])

and each back-propagation step from time "t+1" back to time "t" involves multiplying the gradient by the Jacobian matrix dH at this point, which is an (N+M)*(N+M) matrix.

However, in my particular case, this general approach is computational overkill, because the variables (x_n) do not interact directly with one another; in other words, the X-to-X block of dH is diagonal. The naive back-propagation would thus mean multiplying by a huge (N+M)*(N+M) matrix that is mostly zeros.
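The block structure of dH can be checked numerically on a toy instance. The sketch below uses hypothetical `F` and `G` (not given in the post), scalar x_n, and a finite-difference Jacobian helper `jacobian_fd` invented for illustration. From the scheme itself, dH has a diagonal N x N block (X w.r.t. X), a zero M x N block (Y w.r.t. X), a dense N x M block (X w.r.t. Y) and an M x M block (Y w.r.t. Y). For N=1000 and M=100 that is at most N + N*M + M*M = 111,000 structurally nonzero entries out of (N+M)^2 = 1,210,000.

```python
import math

# Hypothetical F and G, for illustration only.
def G(Y):
    mean = sum(Y) / len(Y)
    return [mean - y for y in Y]

def F(x, y):
    return math.exp(-(x - y) ** 2)

def H(Z, N, eps=0.1):
    """Global step on the stacked state Z = (X, Y); the first N entries are X."""
    X, Y = Z[:N], Z[N:]
    X_next = [x + eps * sum(F(x, y) for y in Y) for x in X]
    Y_next = [y + eps * g for y, g in zip(Y, G(Y))]
    return X_next + Y_next

def jacobian_fd(Z, N, h=1e-6):
    """Central finite-difference Jacobian: J[i][j] = d H_i / d Z_j."""
    D = len(Z)
    J = [[0.0] * D for _ in range(D)]
    for j in range(D):
        Zp = list(Z)
        Zp[j] += h
        Zm = list(Z)
        Zm[j] -= h
        Hp, Hm = H(Zp, N), H(Zm, N)
        for i in range(D):
            J[i][j] = (Hp[i] - Hm[i]) / (2 * h)
    return J

N, M = 4, 3
Z = [0.1, 0.4, -0.3, 0.8] + [0.0, 0.5, 1.0]
J = jacobian_fd(Z, N)

# d X_next / d X: zero off the diagonal, since x_n never sees x_k for k != n.
xx_offdiag = max(abs(J[i][j]) for i in range(N) for j in range(N) if i != j)
# d Y_next / d X: zero, since Y evolves independently of X.
yx = max(abs(J[N + i][j]) for i in range(M) for j in range(N))
# d X_next / d Y: number of entries that are clearly nonzero in this toy case.
xy_nonzero = sum(1 for i in range(N) for j in range(M) if abs(J[i][N + j]) > 1e-8)
print(xx_offdiag, yx, xy_nonzero)
```

The dense N x M coupling block is the part that cannot be avoided: every x_n depends on every y_m, so any exact gradient w.r.t. Y has to touch all N*M pairs at each time step.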

So, finally, my question: is keops' autograd (and its interaction with pytorch autograd) optimized enough to exploit this particular "diagonal" structure associated with the reductions? Can I rely on a simple pytorch call l.backward() to compute the gradient of l=L(X) w.r.t. Y0, or will that be largely sub-optimal in terms of computation time? To stress the issue again, my X variables are much more numerous (N~=1000) than my Y variables (M~=100), so exploiting the independence structure of the X variables is crucial for me.
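For reference, reverse-mode differentiation (the mode behind `l.backward()`) propagates vector-Jacobian products through each step and does not need to form dH as a matrix. The hand-written sketch below illustrates what that backward sweep looks like for this scheme, in plain Python with no keops or pytorch involved. `F`, `G`, the loss `L` (half the sum of squares of the trajectory) and all helper names are hypothetical choices made for illustration. The sketch says nothing about the kernels keops actually generates; it only shows that an exact backward step costs O(N*M) for the reduction plus the cost of the vector-Jacobian product of G.

```python
import math

# Hypothetical F, G and loss L, for illustration only.
def F(x, y):
    return math.exp(-(x - y) ** 2)

def dF(x, y):
    """Partial derivatives (dF/dx, dF/dy) of F at (x, y)."""
    d = -2.0 * (x - y) * math.exp(-(x - y) ** 2)
    return d, -d

def G(Y):
    mean = sum(Y) / len(Y)
    return [mean - y for y in Y]

def G_vjp(a):
    """Product J_G^T a for the linear G above, without building J_G."""
    mean = sum(a) / len(a)
    return [mean - ak for ak in a]

def forward(X0, Y0, eps, T):
    """Euler scheme; stores every state, since the backward sweep reuses them."""
    Xs, Ys = [list(X0)], [list(Y0)]
    for _ in range(T):
        X, Y = Xs[-1], Ys[-1]
        Xs.append([x + eps * sum(F(x, y) for y in Y) for x in X])
        Ys.append([y + eps * g for y, g in zip(Y, G(Y))])
    return Xs, Ys

def loss(Xs):
    """l = L(X): half the sum of squares over the whole X trajectory."""
    return 0.5 * sum(x * x for X in Xs for x in X)

def grad_Y0(X0, Y0, eps, T):
    """Backward sweep: gradient of l w.r.t. Y0 via vector-Jacobian products."""
    Xs, Ys = forward(X0, Y0, eps, T)
    aX = list(Xs[T])        # dl/dX[:, T]
    aY = [0.0] * len(Y0)    # dl/dY[:, T] is zero: Y[:, T] never feeds into l
    for t in range(T - 1, -1, -1):
        X, Y = Xs[t], Ys[t]
        new_aY = [a + eps * g for a, g in zip(aY, G_vjp(aY))]
        new_aX = []
        for n, x in enumerate(X):
            dsum = 0.0
            for m, y in enumerate(Y):
                dx, dy = dF(x, y)
                dsum += dx
                new_aY[m] += eps * aX[n] * dy   # dense N x M coupling
            # Diagonal X block: x_n only feeds x_n; x is the direct dL/dX[n, t].
            new_aX.append(x + aX[n] * (1.0 + eps * dsum))
        aX, aY = new_aX, new_aY
    return aY
```

The result of `grad_Y0` can be compared against central finite differences of `loss(forward(...)[0])` on a small instance as a sanity check.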