Closed: cglwn closed this issue 2 years ago
Hey, I was taking a look at this and hacking on it here. The original algorithm is heavy on control flow, with lots of conditional variable updates and loop breaks, so the implementation ends up pretty unappealing in order to get it to jit-compile. I assume my implementation is non-idiomatic JAX.
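For what it's worth, the usual pattern for making this kind of break-heavy loop jittable is `jax.lax.while_loop`, with the break condition carried as data in the loop state. A minimal sketch of the pattern (the update rule here is a made-up stand-in, not the onenormest iteration):

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative only: a loop with an early exit and conditional state
# updates, rewritten as a jit-compatible lax.while_loop. The carry holds
# (iteration, estimate, done); `done` plays the role of Python `break`.
@jax.jit
def iterate_until_no_improvement(x0, max_iters=50):
    def cond(carry):
        i, _, done = carry
        return jnp.logical_and(i < max_iters, jnp.logical_not(done))

    def body(carry):
        i, est, _ = carry
        new_est = jnp.sqrt(est + 1.0)         # stand-in update rule
        done = jnp.abs(new_est - est) < 1e-6  # "break" condition as data
        return (i + 1, new_est, done)

    _, est, _ = lax.while_loop(cond, body, (0, x0, False))
    return est
```

The same shape works for the resampling loop: fold every conditionally-updated variable into the carry tuple and encode each `break` as a boolean in the carry.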
I do have it passing the scipy test suite with a few small relaxations, e.g. since the algorithm is stochastic, it resamples one more time than the scipy implementation in one test case.
I also benchmarked it here. For a 4096x4096 matrix, the jax implementation on a GPU is ~8x as fast as the scipy implementation. The jax implementation on a CPU is ~5.6x slower than the scipy implementation.
I would love to know whether this is in the direction of something mergeable, and what sorts of changes it might need. For example, maybe not jitting the whole function and using some Python control flow would make it significantly cleaner.
This looks great. I have a hacky implementation here based on the scipy implementation and ran into the same issues getting it to JIT. Yours gets closer to jitting, so it's perhaps a better base.
I personally don't need an implementation that jits.
Oh, super sorry, I didn't realize you were already working on it or I would have offered to collab. Yeah, your implementation looks great; getting it to jit is super hard :)
I'm not super familiar with JAX's standards for its scipy or numpy APIs, i.e. do you want everything exposed from jax.{numpy,scipy} to be jittable? Additionally, it might be possible to refactor the algorithm to make some of the control flow and state updates more functional and written in more idiomatic JAX.
I see a few options.
This is the first time I've ever hacked on jax so not familiar with which of those options would be preferred :)
On a separate note: since the matrix sampling is random, it will have to break the scipy API and take a PRNGKey.
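As an illustrative sketch of what that could look like (the helper name is hypothetical), the resampling step would draw fresh random sign blocks from an explicit key:

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: draw a fresh n-by-t block of random +1/-1 columns,
# as the block algorithm does when it detects parallel columns. Unlike
# the scipy API, the caller must pass a PRNG key explicitly, since JAX
# has no global random state.
def resample_columns(key, n, t):
    return jax.random.choice(key, jnp.array([-1.0, 1.0]), shape=(n, t))

key = jax.random.PRNGKey(0)
block = resample_columns(key, 4, 2)  # entries are all +1 or -1
```

The caller would split and thread the key through each resampling round, which is the part that can't be hidden behind the scipy signature.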
All good since I'm no longer working on it. It's probably best for a maintainer to chime in on the right path forward.
Ah ok cool. Do you know who might be the right person to ping?
@froystig do you know anyone who might be willing to take a look at this? many thanks in advance 🙏
I sadly do not. Since your implementation passes the scipy test cases, opening a PR might be the speediest way to get feedback.
No worries! Since I'm a bit concerned about the code quality and want some advice on direction, I'll post something in Discussions first when I get some time, before opening a PR.
Appreciate you both looking into this! Also cc @mattjj @jakevdp @shoyer who might find this interesting as well.
I only managed to take a quick look so far. Just an immediate thought, actually less to do with the details of implementation, more about realizing that this may be a substantial piece of code no matter how it's done... Now that you've worked on this some, do you think that the JAX core is the best eventual home for it? Or should it live in its own repo (maybe among related functions), maintained by its active users?
As broad context, we like to keep the JAX core focused in scope, and there's an obvious tradeoff between that and achieving coverage of all of scipy, especially its more difficult parts. We want to be sure we can do a good job of maintaining what we take on, so we've intentionally dialed back our coverage of some of scipy over time.
An example along these lines: we plan to deprecate jax.scipy.minimize in favor of JAXopt, whose unconstrained minimization options will likely soon subsume jax.scipy.minimize.
Maybe this is another such example? What do you think?
@williamberman Yes, please don't hesitate to cross-post on the discussions for feedback on the code!
@froystig Thank you so much for the color! Yes, I think that division of responsibilities makes a lot of sense. If there were a more specific package I could contribute the onenormest code to, that would be great. The sparse-matrix one-norm estimator feels like it belongs in some sort of "linalg++" JAX package, but I don't see anything that falls under that category in the google organization or in the ecosystem blog post linked in the README. Do you know of an existing package it would be a good fit in?
We've had some idle thoughts about making something like a "linalg++ jax package" (including for new iterative linear solvers) but nobody has stepped up to do it. The closest thing that currently exists is JAXopt.
X-posted in discussions https://github.com/google/jax/discussions/12102 :)
@shoyer Taking a look at JAXopt, I can't tell off the top of my head whether this algorithm is a good fit for the package, but I'll open an issue asking!
I agree with @froystig's suggestion. For the time being, I think this code should live in the repository of the project that needs it. I wouldn't include it in JAXopt because it's not directly related to optimization and would become a substantial amount of code for us to maintain. Another potential issue: I don't think the JAX community has agreed on an API for linear operators yet.
@mblondel agreed, looking closer at JAXopt it doesn't really look like a good fit. @froystig given there's no good place for the code in jax or its auxiliary repos, maybe it makes sense to close the issue?
Sounds good, I'll close. We can take away again that it'd be nice to have a home for linalg things like this, sooner or later. Separately, maybe we should add some guideline (e.g. in our docs) about the intended scope of jax.scipy, to clarify.
By the way, if you're motivated to contribute to JAX (or JAXopt etc.), then I'm sure we can find another thing. If this was more of an exploration in norm estimation, that's cool too—I hope the explorations continue. Thanks for the discussion!
@froystig Sounds good. Yes, I was looking more to contribute to JAX than to explore norm estimation :) I'd love to be pointed in the direction of any particulars!
scipy.sparse.linalg.onenormest estimates the 1-norm of a matrix, and is needed to compute matrix logarithms with inverse scaling and squaring for #5469. This function also has use cases in other matrix functions. The implementation recommended in the matrix function catalog is the block method: https://epubs.siam.org/doi/10.1137/S0895479899356080.
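For context, here is a greatly simplified sketch of the idea: a Hager-style estimate (the t=1 special case), without the random restarts and parallel-column checks of the full Higham-Tisseur block method the catalog recommends. Illustrative only, not the proposed implementation:

```python
import jax.numpy as jnp

# Simplified 1-norm estimation sketch (Hager-style, t=1). The estimate
# is a lower bound on ||A||_1 that is usually exact or nearly so.
def onenormest_simple(A, itmax=5):
    n = A.shape[0]
    x = jnp.ones(n) / n               # start from the uniform vector
    est = 0.0
    for _ in range(itmax):
        y = A @ x
        est = jnp.sum(jnp.abs(y))     # current estimate of ||A||_1
        z = A.T @ jnp.sign(y)         # subgradient-like step
        j = jnp.argmax(jnp.abs(z))
        x = jnp.zeros(n).at[j].set(1.0)  # jump to the most promising e_j
    return est
```

The block method generalizes this to t search vectors at once, which is where the random resampling and the heavy control flow discussed above come from.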