jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Implement returning convergence info in scipy.sparse.linalg solvers #7909

Open norabelrose opened 3 years ago

norabelrose commented 3 years ago

My feature request is simple: instead of always returning None in the info field, the matrix-free linear solvers in jax.scipy.sparse.linalg should return some sort of container object (e.g. a namedtuple or a dict) with converged, singular, and/or num_iter fields, which could be used to determine whether the algorithm actually converged.
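One possible shape for such a container (purely illustrative; the field names follow the ones mentioned above and are not an existing JAX API):

```python
import typing

class SolveInfo(typing.NamedTuple):
    # Hypothetical container; field names mirror the ones requested above.
    converged: bool   # residual dropped below tolerance
    singular: bool    # solver detected a (near-)singular operator
    num_iter: int     # iterations actually performed
```

A namedtuple keeps the result a flat, pytree-friendly value while still giving each field a name.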

My use case is that I am currently implementing a boundary value problem solver in JAX using BiCGSTAB to find the Newton step direction, like this:

residuals, jvp_fun = jax.linearize(global_system, y_and_p)
direction = jax.scipy.sparse.linalg.bicgstab(A=jvp_fun, b=residuals, x0=direction, maxiter=100)[0]

The solver actually appears to work quite well, except that I have no way of knowing when the Jacobian is singular. This makes the function essentially unusable for me: I'm finding that I have to copy the implementation of jax.scipy.sparse.linalg.bicgstab into my project and tweak it to return the number of iterations, and if that number equals maxiter, I assume it didn't converge. Clearly, though, this is suboptimal and should be relatively easy to fix. I would be willing to contribute a PR to implement this functionality, but I wanted to post here first to make sure there isn't some major reason why it hasn't already been done.
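In the meantime, one workaround that avoids copying the solver source (a sketch, assuming A is a matvec callable like jvp_fun above and that recomputing one residual is acceptable) is to check convergence manually after the solve:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

def solve_with_check(A, b, tol=1e-5, maxiter=100):
    # bicgstab currently returns (x, None), so recompute the residual
    # of the returned solution to decide whether the solve converged.
    # A is assumed to be a callable computing the matrix-vector product.
    x, _ = bicgstab(A, b, tol=tol, maxiter=maxiter)
    residual = jnp.linalg.norm(A(x) - b)
    converged = residual <= tol * jnp.linalg.norm(b)
    return x, converged
```

Note this only detects non-convergence; it still can't distinguish a singular Jacobian from a solve that merely ran out of iterations.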

PhilipVinc commented 3 years ago

Hi @norabelrose

This issue is a duplicate of #4322. A week ago I opened PR #7825 to address exactly this issue.

The main sticking point right now is discussing with the Google folks what interface to implement: JAX follows the standard scipy interface, which returns a single number that is 0 if convergence happened and niter if convergence failed. Unfortunately, this discards the number of steps taken when convergence does happen.

Scipy developers are aware of this issue (tracked in https://github.com/scipy/scipy/issues/10474) and are willing to accept a somewhat awkward hack to always return the number of steps in a backwards-compatible manner, but someone needs to do that work.
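For concreteness, one backwards-compatible trick in that spirit (a sketch only, not necessarily the exact proposal in the scipy thread) is an int subclass that compares like the old info value but also carries the iteration count:

```python
class ConvergenceInfo(int):
    """Behaves like the old scalar info (0 on success, niter on
    failure) while also exposing the iteration count as an attribute.
    Hypothetical sketch, not an existing scipy or JAX class."""

    def __new__(cls, converged, niter):
        obj = super().__new__(cls, 0 if converged else niter)
        obj.converged = converged
        obj.niter = niter
        return obj
```

Existing code that tests `info == 0` keeps working, while new code can read `info.niter` either way.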

The question now is what to do in JAX: implement the same (ugly) interface that scipy will implement, to maintain scipy compatibility, or implement a namedtuple/dict design (like the one I sketched in my PR), which is not compatible but easier to use?

If a core dev were to answer this question, I could finalize my PR.

jakevdp commented 3 years ago

Thanks @PhilipVinc - I think you've hit the core of the issue: jax.scipy by design implements the API of scipy. There is no doubt that the scipy API could be improved in places, but I'm not sure that the JAX mirror is the place to do that.

In terms of jax.scipy.optimize in particular, this is one reason for the creation of the JAXopt package, which provides a new, well-thought-out API for optimization in JAX. We will likely not diverge much from scipy.optimize within jax.scipy.optimize, so if you're looking for improved optimization APIs, I'd suggest taking a look at JAXopt.

hawkinsp commented 3 years ago

And indeed, I think we're considering removing jax.scipy.optimize in favor of jaxopt (or just making it a thin shim that calls jaxopt if you have it installed).

froystig commented 3 years ago

cc @shoyer and @mblondel for the move to jaxopt-backed optimize.

shoyer commented 3 years ago

I think my loose preference here would be to do an ugly backwards-compatible workaround like the one suggested in https://github.com/scipy/scipy/issues/10474#issuecomment-915249531. I wouldn't actually subclass from tuple, but we could certainly support unpacking/indexing like a pair of elements for backwards compatibility.

If we are going to go to the trouble of setting up a new module with its own calling conventions, then I think going to a class based interface like that in JAXopt could make a lot of sense for more flexibility. It's a lot more extensible to be able to call init/update methods from your own loops rather than to rely upon callbacks like SciPy, e.g., if you wanted to do autodiff through a solver step to optimize the parameters of a preconditioner.
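To make the "differentiate through a solver step" point concrete, here is a toy sketch of the init/update pattern (hypothetical names, loosely modeled on JAXopt's design; plain gradient descent stands in for a real solver). Because each update is a pure function, you can take gradients through unrolled steps, e.g. with respect to a step-size parameter:

```python
import jax
import jax.numpy as jnp

def init(x0):
    # For this toy solver the state is just the current iterate.
    return x0

def update(state, grad_fn, step_size):
    # One pure solver step: a gradient-descent update.
    return state - step_size * grad_fn(state)

def loss_after_k_steps(step_size, x0, k=5):
    # Run k solver steps on f(x) = ||x - 1||^2 and report the final loss.
    f = lambda x: jnp.sum((x - 1.0) ** 2)
    x = init(x0)
    for _ in range(k):
        x = update(x, jax.grad(f), step_size)
    return f(x)

# Differentiate through the unrolled solver w.r.t. the step size --
# the kind of thing a callback-based API makes awkward.
d_loss = jax.grad(loss_after_k_steps)(0.1, jnp.array([0.0, 2.0]))
```

The same idea is what would let you optimize, say, preconditioner parameters by treating them as an argument of update.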

PhilipVinc commented 3 years ago

I wasn't discussing scipy.optimize at all (nor was @norabelrose), because I don't really know that part of the API: we were discussing scipy.sparse.linalg.[cg/bicgstab/...], though the arguments are probably similar.

I'm not even sure scipy.sparse.linalg fits within the realm of JAXopt (it seems to me it does not, though some solvers with a different API could be re-exported from there).

@shoyer It seems to me this approach is the one with the most consensus? How would you go about it? Would it entail defining a custom dataclass/pytree that unpacks/indexes like a tuple but actually contains a whole lot of extra information?

Could we tentatively agree on what extra information to supply, if that is what you propose?

mblondel commented 3 years ago

I would be in favor of following what scipy does, since jax.scipy is supposed to be a port of it.

Adding iterative linear system solvers such as CG to JAXopt could make sense at some point. I agree that they fit well in the init/update API that we have.

shoyer commented 3 years ago

> I'm not even sure if scipy.sparse.linalg fits within the realms of jaxOpt (It seems to me it does not? though some solvers with a different API could be re-exported from there).

I agree, it's not entirely clear that JAXopt is the right place to put this. Perhaps we need a new "JAX iterative solvers" library (like IterativeSolvers.jl), or we could make a new sub-module jax.solvers. I was referring more to the API design than to where to actually put the code.

> @shoyer It seems to me this approach is the one that has the most consensus? How would you go on about this? Would it entail defining a custom dataclass/PyTree that unpacks/indexes like a tuple but that actually contains a whole lot of extra information?

Yes, this was my thinking exactly. The extra fields should probably include at least the number of iterations used and a boolean indicating whether or not the solver converged.

Something like:

import dataclasses
import typing

import jax.numpy as jnp

PyTree = typing.Any

@dataclasses.dataclass
class SolverResult:
  x: PyTree
  success: jnp.ndarray  # bool: did the solver converge?
  nit: jnp.ndarray  # int: number of iterations performed

  @property
  def status(self):
    # scipy convention: 0 on success, iteration count on failure
    return jnp.where(self.success, 0, self.nit)

  def __len__(self):
    return 2

  def __getitem__(self, key):
    # unpack/index like the existing (x, info) pair
    return (self.x, self.status)[key]

  # pytree flatten/unflatten methods would go here