jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Wrappers for `scipy.linalg` quadratic control solvers (Lyapunov, ARE) #19109

Open jessegrabowski opened 10 months ago

jessegrabowski commented 10 months ago


I'm interested in implementing solve_discrete_lyapunov, solve_continuous_lyapunov, and solve_discrete_are from scipy.linalg as JAX primitives. My particular use case is Kalman filtering: these functions are handy for computing initial and steady-state covariance matrices, but they also have wide application in linear-quadratic control. Gradients are derived in this paper, https://arxiv.org/pdf/2011.11430.pdf, and I've also done implementations in PyTensor here and here. I rely heavily on compiling PyTensor graphs to JAX for high-performance scans, so not having these functions is a bit of a pain point for me at the moment.
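To make the Kalman filtering use case concrete, here is a minimal JAX sketch of a discrete Lyapunov solve, assuming a small state dimension. It uses the Kronecker formulation, the same idea behind SciPy's method='direct' for solve_discrete_lyapunov; the name solve_discrete_lyapunov_direct is hypothetical and not an existing JAX API. Because it is built only from jnp.kron and jnp.linalg.solve it is jit-able and differentiable, but it forms an n^2 x n^2 system (O(n^6) cost), so it is not a substitute for a proper LAPACK-backed primitive.

```python
import jax
import jax.numpy as jnp


def solve_discrete_lyapunov_direct(a, q):
    """Solve A X A^H - X + Q = 0 for X (hypothetical helper, not a JAX API).

    Uses the row-major identity vec(A X B) = kron(A, B^T) vec(X); with
    B = A^H this gives (I - kron(A, conj(A))) vec(X) = vec(Q).
    The linear system is n^2 x n^2, so cost grows as O(n^6).
    """
    n = a.shape[0]
    lhs = jnp.eye(n * n, dtype=a.dtype) - jnp.kron(a, a.conj())
    x = jnp.linalg.solve(lhs, q.reshape(-1))
    return x.reshape(q.shape)


# Usage: unconditional (stationary) state covariance for x_{t+1} = T x_t + R eta,
# eta ~ N(0, Q_eta), which solves P0 = T P0 T^T + R Q_eta R^T.
T = jnp.array([[0.5, 0.1],
               [0.0, 0.3]])
RQR = jnp.eye(2)  # stands in for R Q_eta R^T in this toy example
P0 = jax.jit(solve_discrete_lyapunov_direct)(T, RQR)

# The solve is differentiable, e.g. with respect to the transition matrix.
dtrace_dT = jax.grad(lambda t: jnp.trace(solve_discrete_lyapunov_direct(t, RQR)))(T)
```

A unique solution exists only when no product of two eigenvalues of T equals 1 (for covariance use, T should be stable), which the toy T above satisfies.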

I didn't see these functions in a quick search of the codebase, but before starting a PR I wanted to check that 1) a contribution would be welcome, and 2) they don't already exist elsewhere in the JAX ecosystem.

shoyer commented 10 months ago

How do you think these solvers measure up against our rubric for JAX SciPy wrappers? https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html

jessegrabowski commented 10 months ago

I think they're on the margin. They fail axis 1 spectacularly, but arguably pass the other axes, with varying levels of difficulty in making those arguments. I think they pass on 2, 3, and 5, and arguably on 6: GitHub code search finds 500-1000 snippets using each of solve_discrete_lyapunov and solve_discrete_are, so they're clearly not as popular as linalg.solve, but they are more widely used than bessel_jn. I guess the weakest case is on 4. Ideally I'd just wrap some calls to LAPACK for the forward computation together with some gradients, but this is likely to be more complicated than I realize (hardware targeting issues? additional package requirements? I have no idea whether either of these would be an issue, but I can imagine that they could be).
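On the "forward computation plus gradients" point, one way to structure this in JAX is jax.custom_vjp, with the backward pass expressed as one more Lyapunov solve using the transposed matrix. The sketch below assumes real inputs and reuses a naive Kronecker solve as a stand-in for whatever LAPACK-backed forward routine a real primitive would call. The adjoint formulas (S = A^T S A + G, grad_Q = S, grad_A = S A X^T + S^T A X) were derived by hand for this example rather than taken from the linked paper, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def _lyap_direct(a, q):
    # Forward solve of X = A X A^T + Q (real case) via the Kronecker formulation.
    # Placeholder for a faster LAPACK-backed solver in a real primitive.
    n = a.shape[0]
    lhs = jnp.eye(n * n, dtype=a.dtype) - jnp.kron(a, a)
    return jnp.linalg.solve(lhs, q.reshape(-1)).reshape(q.shape)


@jax.custom_vjp
def solve_discrete_lyapunov(a, q):
    return _lyap_direct(a, q)


def _fwd(a, q):
    x = _lyap_direct(a, q)
    return x, (a, x)  # residuals saved for the backward pass


def _bwd(res, g):
    a, x = res
    # Adjoint equation S = A^T S A + G: another Lyapunov solve, with A^T in place of A.
    s = _lyap_direct(a.T, g)
    grad_a = s @ a @ x.T + s.T @ a @ x
    grad_q = s
    return grad_a, grad_q


solve_discrete_lyapunov.defvjp(_fwd, _bwd)

# Usage: compare the hand-written VJP with plain autodiff through the Kronecker solve.
A = jnp.array([[0.5, 0.1], [-0.2, 0.3]])
Q = jnp.array([[1.0, 0.2], [0.2, 2.0]])
W = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # arbitrary weights for a scalar loss

grad_a, grad_q = jax.grad(
    lambda a, q: jnp.sum(W * solve_discrete_lyapunov(a, q)), argnums=(0, 1))(A, Q)
ref_a, ref_q = jax.grad(
    lambda a, q: jnp.sum(W * _lyap_direct(a, q)), argnums=(0, 1))(A, Q)
```

The appeal of this structure is that the backward pass needs only one extra solve with the same kind of solver, so a fast forward routine would largely carry over to the gradient.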

On the other hand, quadratic control generally and Kalman filtering specifically aren't exactly niche topics in scientific computing, so I'm sure these functions would see some use if they were available. Plus, scipy.linalg is in scope. That said, I could also see them belonging in something more like jax-opt.
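For the steady-state Kalman/LQR case, a simple (if slow) way to approximate the stabilizing DARE solution in pure JAX is to iterate the Riccati recursion until it settles. This is not the QZ-based approach SciPy's solve_discrete_are uses, and convergence assumes the usual stabilizability and detectability conditions; the function names and the fixed iteration count are illustrative assumptions. Using lax.fori_loop with a static trip count keeps it jit-compatible and reverse-mode differentiable, at the cost of always running every iteration.

```python
import jax
import jax.numpy as jnp


def solve_discrete_are_iter(a, b, q, r, num_iters=500):
    """Approximate X solving the real DARE (hypothetical helper, not a JAX API):

        A^T X A - X - A^T X B (R + B^T X B)^{-1} B^T X A + Q = 0

    by running the Riccati recursion X_{k+1} = f(X_k) from X_0 = Q.
    """
    def riccati_step(_, x):
        # (R + B^T X B)^{-1} B^T X A, computed with a solve instead of an inverse
        gain = jnp.linalg.solve(r + b.T @ x @ b, b.T @ x @ a)
        x_next = a.T @ x @ a - a.T @ x @ b @ gain + q
        # Re-symmetrize so round-off does not accumulate asymmetry
        return 0.5 * (x_next + x_next.T)

    return jax.lax.fori_loop(0, num_iters, riccati_step, q)


def dare_residual(a, b, q, r, x):
    # Left-hand side of the DARE; should be ~0 at the solution.
    gain = jnp.linalg.solve(r + b.T @ x @ b, b.T @ x @ a)
    return a.T @ x @ a - x - a.T @ x @ b @ gain + q


# Toy control-form problem with a stable A, so the iteration converges quickly.
A = jnp.array([[0.9, 0.2],
               [0.0, 0.8]])
B = jnp.array([[1.0],
               [0.5]])
Q = jnp.eye(2)
R = jnp.array([[1.0]])
X = jax.jit(solve_discrete_are_iter, static_argnames="num_iters")(A, B, Q, R)
```

For a Kalman filter's steady-state predicted covariance, the usual duality applies, assuming state-space notation with transition T, observation matrix Z, observation noise H, and state noise R Q R^T: calling it with a = T^T, b = Z^T, q = R Q R^T, r = H turns the recursion into P = T P T^T - T P Z^T (Z P Z^T + H)^{-1} Z P T^T + R Q R^T.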