jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Implement `scipy.sparse.linalg.onenormest` #11568

Closed cglwn closed 2 years ago

cglwn commented 2 years ago

scipy.sparse.linalg.onenormest estimates the 1-norm of a matrix and is needed to compute matrix logarithms via inverse scaling and squaring for #5469. This function also has use cases in other matrix functions.

The implementation recommended in the matrix function catalog is the block method: https://epubs.siam.org/doi/10.1137/S0895479899356080.
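For reference, a minimal sketch of the idea behind the estimator might look like the following. This is Hager's single-vector (t = 1) iteration, which the block algorithm generalizes; it is not the block method from the paper, and the function name and iteration count are just illustrative:

```python
import jax
import jax.numpy as jnp

def onenorm_est_t1(A, num_iters=5):
    # Simplified single-vector (t = 1) variant of Hager's 1-norm estimator.
    # The real onenormest uses a block of t columns plus random restarts.
    n = A.shape[0]
    x = jnp.ones(n) / n
    est = jnp.asarray(0.0, dtype=A.dtype)
    for _ in range(num_iters):          # fixed trip count keeps this jit-friendly
        y = A @ x
        est = jnp.maximum(est, jnp.sum(jnp.abs(y)))  # lower bound on ||A||_1
        z = A.T @ jnp.sign(y)
        j = jnp.argmax(jnp.abs(z))      # column most likely to raise the estimate
        x = jnp.zeros(n, A.dtype).at[j].set(1.0)
    return est

key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (64, 64))
print(onenorm_est_t1(A))                       # estimated 1-norm (a lower bound)
print(jnp.max(jnp.sum(jnp.abs(A), axis=0)))    # exact 1-norm for comparison
```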

williamberman commented 2 years ago

Hey, I was taking a look at this and hacking on it here. The original algorithm is heavy on control flow, with lots of conditional variable updates and loop breaks, so making it jit-compile leaves the implementation pretty unappealing. I would assume my implementation is non-idiomatic jax.

I do have it passing the scipy test suite with a few small relaxations, e.g. the algorithm is stochastic and in one test case resamples one more time than the scipy implementation does.

I also benchmarked it here. For a 4096x4096 matrix, the jax implementation on a GPU is ~8x as fast as the scipy implementation. The jax implementation on a CPU is ~5.6x slower than the scipy implementation.

I would love to know whether this is headed in the direction of something mergeable and what sorts of changes it might need. For example, not jitting the whole function and using some Python control flow might make it significantly cleaner.
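To give a flavor of what "control flow heavy" means in practice, here's a rough sketch (not lifted from my branch) of the general pattern: a data-dependent `break` has to become a `lax.while_loop` under jit, with the break condition folded into the loop predicate and every mutated variable carried explicitly as loop state. The convergence test is the one from Hager's iteration; the rest is illustrative.

```python
import jax.numpy as jnp
from jax import lax

def iterate_until_converged(A, x0, max_iters=20):
    # Loop state carries the iteration count, the current probe vector,
    # and a "done" flag that stands in for the original Python `break`.
    def cond(state):
        k, _, done = state
        return jnp.logical_and(k < max_iters, jnp.logical_not(done))

    def body(state):
        k, x, _ = state
        y = A @ x
        z = A.T @ jnp.sign(y)
        done = jnp.max(jnp.abs(z)) <= z @ x   # Hager's convergence test
        j = jnp.argmax(jnp.abs(z))
        x_next = jnp.zeros_like(x).at[j].set(1.0)
        # keep the old x on the final (converged) step
        return k + 1, jnp.where(done, x, x_next), done

    _, x, _ = lax.while_loop(cond, body, (jnp.asarray(0), x0, jnp.asarray(False)))
    return jnp.sum(jnp.abs(A @ x))
```

A caller would seed it with the usual starting vector, e.g. `iterate_until_converged(A, jnp.ones(n) / n)`.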

cglwn commented 2 years ago

This looks great. I have a hacky implementation here based on the scipy implementation and ran into the same issues getting it to JIT. Yours gets closer to jitting, so it's perhaps a better base.

I personally don't need an implementation that jits.

williamberman commented 2 years ago

Oh, super sorry, I didn't realize you were already working on it or I would have offered to collaborate. Yeah, your implementation looks great; getting it to jit is super hard :)

I'm not super familiar with jax's standards for its scipy or numpy API, e.g. do you want everything exposed from jax.{numpy,scipy} to be jittable? It might also be possible to refactor the algorithm so the control flow and state updates are more functional and written in more idiomatic jax.

I see a few options:

  1. Stay as close to the scipy implementation as possible, with jitting as an afterthought.
  2. Stay close to the scipy implementation with minor modifications to jit portions of the function.
  3. Jit the full function while staying close to the scipy implementation -- results in non-idiomatic jax but is easier to verify as correct than option 4.
  4. Jit the full function by refactoring the original algorithm into more idiomatic jax.

This is the first time I've ever hacked on jax, so I'm not familiar with which of those options would be preferred :)

Separate note: since the matrix sampling is random, it will have to break the scipy API and take a PRNGKey.
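Something like the following, where the names and defaults are hypothetical and the body is only the first, non-iterative step of the block algorithm, just to show the key being threaded through:

```python
import jax
import jax.numpy as jnp

def onenormest(key, A, t=2):
    # Hypothetical signature: scipy draws its random resampling vectors
    # internally, but a JAX version takes an explicit PRNG key.
    n = A.shape[0]
    # First probe column is ones/n; the rest are random +-1/n columns.
    # Each column has unit 1-norm, so max_j ||A x_j||_1 lower-bounds ||A||_1.
    rand_cols = jax.random.rademacher(key, (n, t - 1)).astype(A.dtype) / n
    X = jnp.concatenate([jnp.ones((n, 1), A.dtype) / n, rand_cols], axis=1)
    return jnp.max(jnp.sum(jnp.abs(A @ X), axis=0))

key = jax.random.PRNGKey(0)
sample_key, data_key = jax.random.split(key)
A = jax.random.normal(data_key, (32, 32))
print(onenormest(sample_key, A))
```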

cglwn commented 2 years ago

All good since I'm no longer working on it. It's probably best for a maintainer to chime in on the right path forward.

williamberman commented 2 years ago

Ah ok cool. Do you know who might be the right person to ping?

williamberman commented 2 years ago

@froystig do you know anyone who might be willing to take a look at this? many thanks in advance 🙏

cglwn commented 2 years ago

> Ah ok cool. Do you know who might be the right person to ping?

I sadly do not. Since your implementation passes the scipy test cases, opening a PR might be the speediest way to get feedback.

williamberman commented 2 years ago

> Ah ok cool. Do you know who might be the right person to ping?
>
> I sadly do not. Since your implementation passes the scipy test cases, opening a PR might be the speediest way to get feedback.

No worries! Ah yeah, since I'm a bit concerned about the code quality and want some advice on direction, I'll post something in discussions first when I get some time, before opening a PR.

froystig commented 2 years ago

Appreciate you both looking into this! Also cc @mattjj @jakevdp @shoyer who might find this interesting as well.

I only managed to take a quick look so far. Just an immediate thought, actually less to do with the details of implementation, more about realizing that this may be a substantial piece of code no matter how it's done... Now that you've worked on this some, do you think that the JAX core is the best eventual home for it? Or should it live in its own repo (maybe among related functions), maintained by its active users?

As broad context, we like to keep the JAX core focused in scope, and there's an obvious tradeoff between that and achieving coverage of all of scipy, especially its more difficult parts. We want to be sure we can do a good job of maintaining what we take on, so we've intentionally dialed back our coverage of some of scipy over time.

An example along these lines is that we plan to deprecate jax.scipy.minimize in favor of JAXopt, whose unconstrained minimization options will soon likely subsume jax.scipy.minimize.

Maybe this is another such example? What do you think?

froystig commented 2 years ago

@williamberman Yes, please don't hesitate to cross-post on the discussions for feedback on the code!

williamberman commented 2 years ago

@froystig Thank you so much for the color! Yes, I think that division of responsibilities makes a lot of sense. If there were a more specific package I could contribute the onenormest code to, that would be great. The sparse-matrix 1-norm estimator feels like it belongs in some sort of linalg++ jax package, but I don't see anything that falls under that category in the google organization or the ecosystem blog post linked in the readme. Do you know of an existing package it would be a good fit in?

shoyer commented 2 years ago

We've had some idle thoughts about making something like a "linalg++ jax package" (including for new iterative linear solvers) but nobody has stepped up to do it. The closest thing that currently exists is JAXopt.

williamberman commented 2 years ago

X-posted in discussions https://github.com/google/jax/discussions/12102 :)

williamberman commented 2 years ago

@shoyer Taking a look at JAXopt, I can't tell off the top of my head whether this algorithm is a good fit for the package, but I'll open an issue asking!

mblondel commented 2 years ago

I agree with @froystig's suggestion. For the time being, I think this code should live in the repository of the project that needs it. I wouldn't include it in JAXopt because it's not directly related to optimization and would become a substantial amount of code for us to maintain. Another potential issue: I don't think the JAX community has agreed on an API for linear operators yet.

williamberman commented 2 years ago

@mblondel agreed, looking closer at JAXopt it doesn't really look like a good fit. @froystig given there's no good place for the code in jax or its auxiliary repos, maybe it makes sense to close the issue?

froystig commented 2 years ago

Sounds good, I'll close. We can take away again that it'd be nice to have a home for linalg things like this, sooner or later. Separately, maybe we should add some guideline (e.g. in our docs) about the intended scope of jax.scipy, to clarify.

By the way, if you're motivated to contribute to JAX (or JAXopt etc.), then I'm sure we can find another thing. If this was more of an exploration in norm estimation, that's cool too—I hope the explorations continue. Thanks for the discussion!

williamberman commented 2 years ago

@froystig sounds good, yes, I was looking more to contribute to jax than to explore norm estimation :) I would love to be pointed in the direction of any particulars!