terhorst closed this issue 1 year ago
Edit: I just discovered the `AbstractTerm` class, which seems promising, but I need to stare at it more.
I think you want to write `ODETerm(lambda t, y, args: A(t) @ y)`? You've missed off the `y` accidentally. Moreover, as you're now explicitly calling `__matmul__`, things should work for you.
(The tensordot you've spotted is something else entirely! Just a sum over stages in the Runge-Kutta method.)
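To illustrate the point about the sum over stages, here is a minimal NumPy sketch of a generic explicit Runge-Kutta step. It is not diffrax's actual code: the classical RK4 tableau, the `rk_step` helper, and the `DiagonalOperator` class are all hypothetical stand-ins chosen for illustration. The `tensordot` only combines the stage values returned by the vector field, while the custom `__matmul__` is invoked inside the vector field itself.

```python
import numpy as np

# Classical RK4 tableau (a textbook method, used here only for illustration).
C = np.array([0.0, 0.5, 0.5, 1.0])
A_TABLEAU = np.array([
    [0.0, 0.0, 0.0, 0.0],
    [0.5, 0.0, 0.0, 0.0],
    [0.0, 0.5, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
])
B = np.array([1 / 6, 1 / 3, 1 / 3, 1 / 6])


def rk_step(vector_field, t, y, dt):
    """One explicit Runge-Kutta step for y' = vector_field(t, y)."""
    ks = np.zeros((len(B),) + y.shape)  # stage values, one per row
    for i in range(len(B)):
        # Sum over stages: weights (s,) contracted with stages (s, ...).
        # Stages not yet computed are zero and carry zero weight.
        y_stage = y + dt * np.tensordot(A_TABLEAU[i], ks, axes=1)
        ks[i] = vector_field(t + C[i] * dt, y_stage)
    # Final combination is again just a weighted sum over stages.
    return y + dt * np.tensordot(B, ks, axes=1)


class DiagonalOperator:
    """Hypothetical custom operator; only its __matmul__ is ever used."""

    def __init__(self, diag):
        self.diag = diag

    def __matmul__(self, y):
        return self.diag * y


def A(t):
    # Time-dependent operator, A(t) = -(1 + t) * I.
    return DiagonalOperator(-(1.0 + t) * np.ones(3))


def vector_field(t, y):
    return A(t) @ y  # the custom __matmul__ is called here


dt = 0.01
y = np.ones(3)
for step in range(100):
    y = rk_step(vector_field, step * dt, y, dt)
```

For this test problem the exact solution at $t = 1$ is $e^{-3/2}$ in every component, which gives a quick sanity check on the result.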
Ack, I totally misunderstood the error I was getting. Sorry for the dumb question!
Hi Patrick,
Thanks for the cool library! I have a question about how to use `diffrax` to solve an ODE that has Kronecker product structure. Specifically, I'm wanting to solve $$f'(t) = A(t)f(t)$$ where $A(t)\in \mathbb{R}^{N \times N}$ is a matrix. (Basically matrix exponential but with a non-constant rate matrix.) In my application, $A$ is very large (also sparse), but it can be written as a sum of Kronecker products, $A=\sum_i \bigotimes_j A_{ij}$ where $A_{ij} \in \mathbb{R}^{n_j\times n_j}$. To evaluate the product $A(t) x$ for $x \in \mathbb{R}^{n_1\times\cdots\times n_K}$, it is much cheaper to a) individually multiply each of the $A_{ij}$ against the corresponding axes of $x$, versus b) materializing the full Kronecker product, flattening $x$, and performing a standard matrix-vector multiply.

Ok so here's my question: I've written a simple `KronProd` class that implements `__matmul__` in such a way that `A @ x` does the correct thing when passed a conforming tensor `x`. My hope had been to plug this directly into `diffrax`, basically just as `ODETerm(lambda y, t, args: A(t))` where `A(t)` returns an instance of my custom class.

Unfortunately this doesn't quite work: as best I can tell, every solver eventually ends up invoking https://github.com/patrick-kidger/diffrax/blob/0b93a3c108cff20c201da7b81e141dceff637f4f/diffrax/solver/base.py#L20-L23 which is "hard-wired" to call `jnp.tensordot`. This is equivalent to `a @ bi` in a certain way, but doesn't work for me since it ignores `__matmul__` entirely. AFAICT there's no way to easily override this behavior given the current design of diffrax.

Can you think of any workaround here? Thanks again!
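For readers following along, a class like the one described might look roughly like the sketch below. This is not the actual `KronProd` from the question, which was never posted: the constructor signature, the `apply_along_axis` helper, and the `to_dense` method are assumptions made for illustration. It uses NumPy to stay self-contained; inside diffrax one would presumably need `jax.numpy` instead. The check at the end compares option a) (per-axis multiplication) against option b) (materialize, flatten, multiply), assuming row-major flattening.

```python
import numpy as np


def apply_along_axis(mat, x, axis):
    """Multiply `mat` (n_j x n_j) against axis `axis` of the tensor `x`."""
    # tensordot contracts mat's columns with x's `axis`; the new axis comes
    # out first, so move it back to position `axis`.
    out = np.tensordot(mat, x, axes=([1], [axis]))
    return np.moveaxis(out, 0, axis)


class KronProd:
    """Hypothetical stand-in for A = sum_i kron(A_i1, ..., A_iK)."""

    def __init__(self, terms):
        # terms: list over i of lists over j of square (n_j, n_j) arrays
        self.terms = terms

    def __matmul__(self, x):
        total = np.zeros_like(x)
        for factors in self.terms:
            y = x
            for axis, mat in enumerate(factors):
                y = apply_along_axis(mat, y, axis)
            total = total + y
        return total

    def to_dense(self):
        """Materialize the full N x N matrix (only sensible for tiny tests)."""
        dense = 0
        for factors in self.terms:
            full = np.ones((1, 1))
            for mat in factors:
                full = np.kron(full, mat)
            dense = dense + full
        return dense


rng = np.random.default_rng(0)
shape = (2, 3, 4)
terms = [[rng.standard_normal((n, n)) for n in shape] for _ in range(2)]
A = KronProd(terms)
x = rng.standard_normal(shape)

structured = A @ x                                      # option a)
dense = (A.to_dense() @ x.ravel()).reshape(shape)       # option b)
print(np.allclose(structured, dense))
```

The structured product never forms the $N \times N$ matrix, so its cost scales with the sizes of the individual factors rather than with $N^2$.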