patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Derivative with Kronecker product structure #243

Closed (terhorst closed this issue 1 year ago)

terhorst commented 1 year ago

Hi Patrick,

Thanks for the cool library! I have a question about how to use diffrax to solve an ODE that has Kronecker product structure. Specifically, I want to solve $$f'(t) = A(t)f(t)$$ where $A(t)\in \mathbb{R}^{N \times N}$ is a matrix. (Basically a matrix exponential, but with a non-constant rate matrix.) In my application, $A$ is very large (and also sparse), but it can be written as a sum of Kronecker products, $A=\sum_i \bigotimes_j A_{ij}$, where $A_{ij} \in \mathbb{R}^{n_j\times n_j}$ and $N = \prod_j n_j$. To evaluate the product $A(t) x$ for $x \in \mathbb{R}^{n_1\times\cdots\times n_K}$, it is much cheaper to (a) multiply each $A_{ij}$ individually against the corresponding axis of $x$ than to (b) materialize the full Kronecker product, flatten $x$, and perform a standard matrix-vector multiply.
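For a single term with $K = 2$, the saving in (a) comes from a standard identity. Writing $\mathrm{vec}(X)$ for the row-major (C-order, i.e. NumPy reshape) flattening of $X \in \mathbb{R}^{n_1 \times n_2}$:

$$\big[(A_{i1} \otimes A_{i2})\,\mathrm{vec}(X)\big]_{(p,q)} = \sum_{r,s} (A_{i1})_{pr}\,(A_{i2})_{qs}\,X_{rs} = \big(A_{i1} X A_{i2}^\top\big)_{pq}, \qquad\text{so}\qquad (A_{i1} \otimes A_{i2})\,\mathrm{vec}(X) = \mathrm{vec}\big(A_{i1} X A_{i2}^\top\big).$$

The right-hand side costs $O(n_1 n_2 (n_1 + n_2))$ operations versus $O(n_1^2 n_2^2)$ for the dense product. In general, contracting each factor against its own axis costs $O(N \sum_j n_j)$ per term instead of $O(N^2)$.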

Ok so here's my question: I've written a simple KronProd class that implements __matmul__ in such a way that A @ x does the correct thing when passed a conforming tensor x. My hope had been to plug this directly into diffrax, basically just as ODETerm(lambda y, t, args: A(t)) where A(t) returns an instance of my custom class.
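The thread doesn't show the KronProd implementation, so here is a hypothetical sketch of what such a class might look like in JAX. The constructor layout (terms[i][j] holds $A_{ij}$) and the dense() helper are assumptions for illustration. A @ x contracts each factor against the matching axis of x and sums over terms, never forming the $N \times N$ matrix.

```python
import functools

import jax.numpy as jnp


class KronProd:
    """Hypothetical operator A = sum_i kron(A_i1, ..., A_iK), applied lazily via @.

    terms[i][j] is the (n_j, n_j) factor A_ij; x must have shape (n_1, ..., n_K).
    """

    def __init__(self, terms):
        self.terms = [[jnp.asarray(a) for a in factors] for factors in terms]

    @staticmethod
    def _apply_term(factors, x):
        for j, a in enumerate(factors):
            # Contract a's column index with axis j of x. tensordot puts the
            # resulting axis first, so move it back to position j.
            x = jnp.moveaxis(jnp.tensordot(a, x, axes=([1], [j])), 0, j)
        return x

    def __matmul__(self, x):
        x = jnp.asarray(x)
        # Sum the contributions of each Kronecker-product term.
        return sum(self._apply_term(factors, x) for factors in self.terms)

    def dense(self):
        # Materialise the full N x N matrix; only sensible for small test cases.
        return sum(functools.reduce(jnp.kron, factors) for factors in self.terms)
```

For small sizes, A.dense() @ x.reshape(-1) reshaped back to x.shape should agree with A @ x, which is a cheap way to check the axis bookkeeping.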

Unfortunately this doesn't quite work: as best I can tell, every solver eventually ends up invoking https://github.com/patrick-kidger/diffrax/blob/0b93a3c108cff20c201da7b81e141dceff637f4f/diffrax/solver/base.py#L20-L23, which is "hard-wired" to call jnp.tensordot. This is equivalent to a @ bi in a certain sense, but it doesn't work for me since it ignores __matmul__ entirely. AFAICT there's no easy way to override this behavior given the current design of diffrax.

Can you think of any workaround here? Thanks again!

terhorst commented 1 year ago

Edit: I just discovered the AbstractTerm class which seems promising, but I need to stare at it more.

patrick-kidger commented 1 year ago

I think you want to write ODETerm(lambda t, y, args: A(t) @ y)? You've missed off the y accidentally. Moreover, as you're now explicitly calling __matmul__, things should work for you.

(The tensordot you've spotted is something else entirely! Just a sum over stages in the Runge-Kutta method.)
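To illustrate the point with a toy example (made-up weights and stage values, not diffrax's actual code): in a Runge-Kutta step the vector field is evaluated once per stage, and the results are then combined with the tableau weights. A tensordot over a leading stage axis is exactly that weighted sum, so it mixes already-computed vector field outputs and never touches A itself.

```python
import jax.numpy as jnp

# Hypothetical weights b_s for a 3-stage method, and stage derivatives k_s
# stacked along a leading "stage" axis; each k_s has the state shape (2, 3).
b = jnp.array([0.25, 0.5, 0.25])
ks = jnp.arange(18.0).reshape(3, 2, 3)

# tensordot with axes=1 contracts b's only axis with the leading axis of ks...
combined = jnp.tensordot(b, ks, axes=1)

# ...which is the same as writing the weighted sum over stages explicitly.
explicit = sum(b[s] * ks[s] for s in range(3))
```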

terhorst commented 1 year ago

Ack, I totally misunderstood the error I was getting. Sorry for the dumb question!