norabelrose opened 3 years ago
Hi @norabelrose
This issue is a duplicate of #4322. A week ago I opened PR #7825 to address exactly this issue.
The main blocker right now is discussing with Google's people what interface to implement: jax implements the standard scipy interface, which is to return a single number that is 0 if convergence happened and niter if convergence failed. Unfortunately, this discards the number of steps taken whenever convergence succeeds.
Scipy developers are aware of this issue (tracked in https://github.com/scipy/scipy/issues/10474) and are willing to accept a somewhat ugly hack that always returns the number of steps in a backwards-compatible manner; however, someone needs to do that work.
The question now is what to do in JAX: implement the same (ugly) interface that scipy will implement, to maintain scipy compatibility, or implement a named-tuple/dict design (like the one I sketched in my PR), which is not compatible but easier to use?
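For concreteness, here is a hypothetical sketch of the two candidate interfaces; the function and field names are illustrative, not taken from the PR:

```python
from typing import NamedTuple

# Hypothetical sketch of the two interface options under discussion;
# names are illustrative, not the actual PR's API.

# Option 1: scipy-style. A single integer `info`: 0 on success, the
# iteration count on failure. The count is lost when the solve succeeds.
def solve_scipy_style(converged: bool, niter: int):
    info = 0 if converged else niter
    return "x", info  # (solution, info)

# Option 2: a named tuple that keeps both pieces of information,
# regardless of whether the solve converged.
class SolveInfo(NamedTuple):
    converged: bool
    niter: int

def solve_namedtuple_style(converged: bool, niter: int):
    return "x", SolveInfo(converged=converged, niter=niter)
```

With option 1, a successful solve returns `info == 0` and the iteration count is unrecoverable; option 2 preserves it at the cost of diverging from scipy.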
If a core dev could answer this question, I could finalize my PR.
Thanks @PhilipVinc - I think you've hit the core of the issue: jax.scipy by design implements the API of scipy. There is no doubt that the scipy API could be improved in places, but I'm not sure that the JAX mirror is the place to do that.
In terms of jax.scipy.optimize in particular, this is one reason for the creation of the JAXopt package, which provides a new well-thought-out API for optimization in JAX. We will likely not diverge much from scipy.optimize within jax.scipy.optimize, so if you're looking for improved optimization APIs, I'd suggest taking a look at JAXopt.
And indeed, I think we're considering removing jax.scipy.optimize in favor of jaxopt (or just making it a thin shim that calls jaxopt if you have it installed).
cc @shoyer and @mblondel for the move to jaxopt-backed optimize.
I think my loose preference here would be to do an ugly backwards compatible work-around like that suggested in https://github.com/scipy/scipy/issues/10474#issuecomment-915249531. I wouldn't actually subclass from tuple, but we could certainly do unpacking/indexing like a pair of elements for backwards compatibility.
If we are going to go to the trouble of setting up a new module with its own calling conventions, then I think going to a class based interface like that in JAXopt could make a lot of sense for more flexibility. It's a lot more extensible to be able to call init/update methods from your own loops rather than to rely upon callbacks like SciPy, e.g., if you wanted to do autodiff through a solver step to optimize the parameters of a preconditioner.
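As a toy illustration of that init/update style (this is not the actual JAXopt API, just the shape of the pattern), a caller-owned loop might look like:

```python
# Toy gradient-descent solver in the init/update style (not the real
# JAXopt API). Because the caller owns the loop, individual steps can be
# inspected, modified, or differentiated through, rather than being
# hidden behind a scipy-style callback interface.
class GradientDescent:
    def __init__(self, grad_fn, step_size=0.1):
        self.grad_fn = grad_fn
        self.step_size = step_size

    def init(self, x0):
        # The state here is just the current iterate; a real solver
        # would also carry momentum, iteration count, etc.
        return x0

    def update(self, state):
        return state - self.step_size * self.grad_fn(state)

# Minimize f(x) = x**2, whose gradient is 2*x.
solver = GradientDescent(grad_fn=lambda x: 2.0 * x)
state = solver.init(4.0)
for _ in range(100):
    state = solver.update(state)
```

Each `update` call is an ordinary function of the state, which is what makes it possible to, say, autodiff through a single solver step.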
I was not discussing scipy.optimize at all (nor was @norabelrose), because I don't really know that part of the API: we were discussing scipy.sparse.linalg.[cg/bicgstab/...], though the arguments are probably similar.
I'm not even sure scipy.sparse.linalg fits within the realm of JAXopt (it seems to me it does not, though some solvers with a different API could be re-exported from there).
@shoyer It seems to me this approach is the one with the most consensus? How would you go about this? Would it entail defining a custom dataclass/PyTree that unpacks/indexes like a tuple but actually contains a whole lot of extra information?
Could we tentatively agree on what extra information to supply, if that is what you propose?
I would be in favor of following what scipy does, since jax.scipy is supposed to be a port of it.
Adding iterative linear system solvers such as CG to JAXopt could make sense at some point. I agree that they fit well in the init/update API that we have.
> I'm not even sure scipy.sparse.linalg fits within the realm of JAXopt (it seems to me it does not, though some solvers with a different API could be re-exported from there).
I agree, it's not entirely clear that JAXopt is the right place to put this. Perhaps we need a new "JAX iterative solvers" library (like IterativeSolvers.jl), or we could make a new sub-module jax.solvers. I was referring more to the API design than to where the code should actually live.
> @shoyer It seems to me this approach is the one with the most consensus? How would you go about this? Would it entail defining a custom dataclass/PyTree that unpacks/indexes like a tuple but actually contains a whole lot of extra information?
Yes, this was my thinking exactly. The extra fields should probably include at least the number of iterations used and a boolean indicating whether or not the simulation converged.
Something like:
import dataclasses
import typing

import jax.numpy as jnp

PyTree = typing.Any

@dataclasses.dataclass
class SolverResult:
    x: PyTree
    success: jnp.ndarray  # bool
    nit: jnp.ndarray  # int

    @property
    def status(self):
        # scipy convention: 0 on success, otherwise the iteration count.
        return jnp.where(self.success, 0, self.nit)

    # Tuple-like behavior for backwards compatibility with (x, info).
    def __len__(self):
        return 2

    def __getitem__(self, key):
        return (self.x, self.status)[key]

    # pytree methods
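A quick self-contained usage sketch of that idea (with numpy standing in for jax.numpy, and the pytree methods omitted), showing how such a class stays unpackable like the old (x, info) pair:

```python
import dataclasses
import numpy as np

# numpy stands in for jax.numpy here purely to keep the sketch
# self-contained; the shape of the class mirrors the one above.
@dataclasses.dataclass
class SolverResult:
    x: np.ndarray
    success: np.ndarray  # bool
    nit: np.ndarray  # int

    @property
    def status(self):
        # scipy convention: 0 on success, otherwise the iteration count.
        return np.where(self.success, 0, self.nit)

    def __len__(self):
        return 2

    def __getitem__(self, key):
        return (self.x, self.status)[key]

result = SolverResult(x=np.zeros(3), success=np.array(True), nit=np.array(12))

# New-style access keeps the iteration count even on success:
assert int(result.nit) == 12

# Old-style scipy-compatible unpacking still works, because Python falls
# back to __getitem__ when no __iter__ is defined:
x, info = result
assert int(info) == 0  # 0 means "converged" in the scipy convention
```

The `__len__`/`__getitem__` pair is what gives tuple-style unpacking and indexing without actually subclassing tuple.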
My feature request is simple: instead of always returning None in the info field, the matrix-free linear solvers in jax.scipy.sparse.linalg should return some sort of container object (i.e. a namedtuple or a dict) that contains converged, singular, and/or num_iter fields, which could be used to determine whether the algorithm actually converged.
My use case is that I am currently implementing a boundary value problem solver in JAX, using BiCGSTAB to find the Newton step direction.
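The code from the original issue is not reproduced here; purely as a stand-in illustration (with numpy's dense solve in place of the matrix-free bicgstab call, and a toy system in place of a real BVP discretization), a Newton iteration of that shape might look like:

```python
import numpy as np

# Stand-in illustration only: np.linalg.solve replaces the matrix-free
# jax.scipy.sparse.linalg.bicgstab call from the actual use case, and
# residual/jacobian below are a toy system, not a BVP discretization.

def residual(u):
    # Toy nonlinear system F(u) = 0.
    return u**3 - np.array([1.0, 8.0])

def jacobian(u):
    return np.diag(3.0 * u**2)

u = np.array([2.0, 1.0])
for _ in range(50):
    # Newton step: solve J(u) @ step = -F(u), then update the iterate.
    step = np.linalg.solve(jacobian(u), -residual(u))
    u = u + step
```

The pain point described below arises exactly at the linear-solve line: if the inner solver silently fails to converge (e.g. because J is singular), the outer Newton loop has no way to know.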
The solver actually appears to work quite well, except that I have no way of knowing when the Jacobian is singular. This makes the function essentially unusable for me, so I'm finding that I have to copy and paste the implementation of jax.scipy.sparse.linalg.bicgstab into my project and tweak it so that it returns the number of iterations. If the number of iterations is equal to maxiter, I assume it didn't converge. Clearly, though, this is suboptimal and should be relatively easy to fix. I would be willing to contribute a PR to implement this functionality if need be, but I wanted to post here first to make sure there isn't some major reason why this hasn't already been implemented.