patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Diagonal/Sparse Control Term #485

Closed. SoerenNagel closed this issue 3 months ago.

SoerenNagel commented 3 months ago

Hi,

I have to solve a bunch of fairly high-dimensional SDEs, but my diffusion matrix (control term) is only diagonal. This turns out to be a significant bottleneck in my case, since I have already optimized the ODE (drift) term.
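For reference, the structure being described can be written as follows (the notation $f$, $\sigma_i$, $W$ is illustrative and not from the thread). With a diagonal diffusion, each component is driven by its own independent Brownian motion, so applying the noise costs $O(d)$ per step instead of the $O(d^2)$ of a dense matrix-vector product:

$$
\mathrm{d}y_t = f(t, y_t)\,\mathrm{d}t + G(t, y_t)\,\mathrm{d}W_t,
\qquad G(t, y) = \operatorname{diag}\big(\sigma_1(t, y), \dots, \sigma_d(t, y)\big)
$$

$$
\Longleftrightarrow \quad \mathrm{d}y^i_t = f^i(t, y_t)\,\mathrm{d}t + \sigma_i(t, y_t)\,\mathrm{d}W^i_t, \qquad i = 1, \dots, d.
$$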

Is there a way to use a sparse matrix to speed up the integration? I tried wrapping diffeqsolve with sparsify() from the jax.experimental.sparse module, but I could not get it to work.

Passing just the diagonal as a vector results in a dot product being taken with the Brownian increment, so every dimension gets the same noise realization. Is there a way to cleverly pass only the diagonal array?
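A small plain-JAX sketch of the difference described above. It assumes the control-term product contracts the vector field against every axis of the Brownian increment, which is mimicked here with `jnp.tensordot`; it is not diffrax's actual code. The dimension, step size, and diagonal values are arbitrary.

```python
import jax
import jax.numpy as jnp

d = 4
key = jax.random.PRNGKey(0)
diag = jnp.arange(1.0, d + 1.0)                      # diagonal entries sigma_i
dW = jnp.sqrt(0.01) * jax.random.normal(key, (d,))   # one increment per dimension, dt = 0.01

# Contracting a (d,) vector against a (d,) increment collapses to a scalar,
# i.e. a single noise value shared by every component.
dot = jnp.tensordot(diag, dW, axes=1)

# The intended diagonal action: elementwise product, independent noise per component.
elementwise = diag * dW

# Same result as the elementwise product, but builds an O(d^2) dense matrix.
dense = jnp.diag(diag) @ dW
```

The elementwise form is what a diagonal diffusion needs; the dense form gives identical numbers but defeats the purpose for large `d`.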

Thanks for the help!

SoerenNagel commented 3 months ago

I just found the example using a lineax linear operator in the docs.