Probabilistic solvers for differential equations in JAX. Adaptive ODE solvers with calibration, state-space model factorisations, and custom information operators. Compatible with the broader JAX scientific computing ecosystem.
I don't know if you have seen this new library, but it looks like it could provide useful, efficient linear operators for probabilistic ODE solvers.
CoLA is a framework for scalable linear algebra, automatically exploiting the structure often found in machine learning problems and beyond. CoLA supports both PyTorch and JAX.
https://github.com/wilson-labs/cola
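To make "exploiting structure" concrete, here is a minimal plain-JAX sketch of the kind of saving a structured linear operator gives. It does not use CoLA's API, and `kron_matvec` is a hypothetical helper written for illustration. The assumed use case is a Kronecker-structured covariance such as `I_d ⊗ Q`, which can arise when the same prior is shared across the `d` ODE dimensions (e.g. an isotropic state-space factorisation). The matrix-vector product is computed from the small factors via the identity (A ⊗ B) vec(X) = vec(A X Bᵀ) (row-major flattening), without ever forming the large Kronecker matrix.

```python
import jax
import jax.numpy as jnp


def kron_matvec(a, b, x):
    """Compute (a ⊗ b) @ x without materialising jnp.kron(a, b).

    Hypothetical helper (not part of CoLA or probdiffeq). With row-major
    flattening, (A ⊗ B) vec(X) = vec(A @ X @ B.T), where x is reshaped
    to a matrix of shape (a.shape[1], b.shape[1]).
    """
    x_mat = x.reshape(a.shape[1], b.shape[1])
    return (a @ x_mat @ b.T).reshape(-1)


# Usage: an assumed covariance I_d ⊗ Q with d = 3 ODE dimensions and a
# 2x2 per-dimension block Q (random here, purely for illustration).
key_q, key_x = jax.random.split(jax.random.PRNGKey(0))
d = 3
q = jax.random.normal(key_q, (2, 2))
x = jax.random.normal(key_x, (d * 2,))

structured = kron_matvec(jnp.eye(d), q, x)
dense = jnp.kron(jnp.eye(d), q) @ x  # reference: builds the full 6x6 matrix
print(jnp.allclose(structured, dense, atol=1e-5))
```

For square factors of sizes n and m, the structured product costs O(nm(n + m)) instead of the O(n²m²) of a dense matvec. A library like CoLA aims to apply this kind of rewrite automatically once the operator is declared with its structure.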