martinResearch opened this issue 5 years ago
I’m sure that these could be useful features, but they will definitely be gated on support for sparse matrices — which unfortunately I would not expect anytime soon. Even then, I would not expect to see sparse Hessians/Jacobians if it requires rewriting every existing JVP rule.
That said, I’m not sure that these are strictly necessary, at least if you’re willing to accept using an iterative method for the linear solve. I’m pretty confident that Levenberg-Marquardt, for example, could be done entirely in a matrix-free way by using Krylov methods like conjugate gradients for the solve. On a related note, I’ve been working on a Newton-Krylov solver leveraging JAX’s autodiff that I hope to have up for review very soon.
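For concreteness, a minimal matrix-free Levenberg-Marquardt step might look something like the sketch below (the residual function and damping value are illustrative, not from this thread): `jax.jvp`/`jax.vjp` apply J and Jᵀ without ever materializing the Jacobian, and `jax.scipy.sparse.linalg.cg` solves the damped normal equations.

```python
import jax
import jax.numpy as jnp

def residuals(x):
    # illustrative nonlinear least-squares residuals, R^2 -> R^3
    return jnp.array([x[0] ** 2 - 1.0, x[0] * x[1] - 2.0, x[1] - 1.0])

def lm_step(x, damping=1e-3):
    """One Levenberg-Marquardt step without materializing the Jacobian."""
    r = residuals(x)
    jv = lambda v: jax.jvp(residuals, (x,), (v,))[1]   # v -> J v
    jtu = lambda u: jax.vjp(residuals, x)[1](u)[0]     # u -> J^T u
    # Solve (J^T J + damping * I) dx = -J^T r with conjugate gradients;
    # cg only ever sees the matrix through this matvec closure.
    matvec = lambda v: jtu(jv(v)) + damping * v
    dx, _ = jax.scipy.sparse.linalg.cg(matvec, -jtu(r))
    return x + dx
```

Iterating `x = lm_step(x)` drives the residual norm down without ever forming J, JᵀJ, or any sparse structure.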
Indeed, one could use Krylov methods. I just found out that the scipy sparse iterative linear solvers support providing the matrices as linear operators, which would make that approach quite easy to use.
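As a sketch of that idea (the function here is illustrative), a JAX `jvp` at a fixed point can be wrapped in a `scipy.sparse.linalg.LinearOperator` and handed directly to scipy's iterative solvers:

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.sparse.linalg import LinearOperator, gmres

# illustrative R^2 -> R^2 function whose Jacobian we never materialize
f = lambda x: jnp.array([3.0 * x[0] + x[1] ** 2, x[0] * x[1] + 2.0 * x[1]])
x0 = jnp.array([1.0, 2.0])

# expose J(x0) to scipy as a matrix-free linear operator
matvec = lambda v: np.asarray(jax.jvp(f, (x0,), (jnp.asarray(v),))[1])
J = LinearOperator(shape=(2, 2), matvec=matvec)

# solve J(x0) dx = b iteratively, e.g. as the inner solve of a Newton step
dx, info = gmres(J, np.array([1.0, 1.0]))
```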
Yes, we’ve written experimental versions of GMRES and CG in JAX.
Any chance something like this library has been or can be used with jax?
@jamesthegiantpeach not at present, no. That's a little off-topic for this issue so I responded in #765.
An approach I used to get a sparse representation is this:
```python
import jax.numpy as jnp
from jax import jit, jacfwd, random

# arbitrary function and its Jacobian
function = jit(lambda x: ...)
jacobian = jit(jacfwd(function))

# estimate the sparsity pattern by sampling the Jacobian at a random input
x = random.uniform(random.PRNGKey(0), (...,))
sparse_ids = jnp.nonzero(jacobian(x))

# evaluate only the (assumed) structurally nonzero entries
sparse_jacobian = jit(lambda x: jacobian(x)[sparse_ids])
```
Is there a better way to do this? Couldn't the sparsity pattern be inferred from `jacobian` without having to sample it?
Is there any chance of sparse Jacobians with the upcoming sparse support (https://github.com/google/jax/issues/765, https://github.com/google/jax/pull/4422)? Or is there currently an efficient way to obtain sparse Jacobians (any scipy format is fine) in JAX?
For badly conditioned problems, using iterative solvers (e.g. CGNR, the Conjugate Gradient method on the Normal Equations) is rather slow, and using direct solvers could improve performance substantially.
@cisprague I think a similar approach as in Julia could be used to infer the sparsity pattern (Structure and Sparsity Detection, https://github.com/JuliaDiff/SparseDiffTools.jl). This could be achieved by writing a non-standard interpreter (or transformation?) as in Sparsity Programming: Automated Sparsity-Aware Optimizations in Differentiable Programming.
> Even then, I would not expect to see sparse Hessians/Jacobians if it requires rewriting every existing JVP rule.
@shoyer Rewriting every existing JVP rule is actually not necessary: a similar approach to https://github.com/JuliaDiff/SparseDiffTools.jl#colorvec-assisted-differentiation could be taken if we could get a matrix coloring from the sparsity pattern mentioned above. We could simply use `jvp` and `hvp` to compute a sparse Jacobian and Hessian, since we also have some support for sparse matrices now. Adding some sparse solvers like cuSolver (https://github.com/google/jax/issues/6544#issuecomment-930483553) could make JAX a compelling library for classical optimization problems.
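A rough sketch of that colorvec idea in JAX terms (the greedy coloring and the `colored_jacobian` helper are illustrative, not an existing API): columns whose sparsity patterns never overlap can share a single `jvp` probe, and the known pattern lets us de-interleave the result.

```python
import numpy as np
import jax
import jax.numpy as jnp

def colored_jacobian(f, x, sparsity):
    """Recover a Jacobian from one jvp per color group.

    `sparsity` is a boolean (m, n) array with the assumed nonzero
    pattern; the result is returned dense here for clarity.
    """
    m, n = sparsity.shape
    # naive greedy column coloring: columns sharing a color class
    # must not share any row of the sparsity pattern
    colors = -np.ones(n, dtype=int)
    for j in range(n):
        c = 0
        while any(colors[k] == c and np.any(sparsity[:, k] & sparsity[:, j])
                  for k in range(j)):
            c += 1
        colors[j] = c
    J = np.zeros((m, n))
    for c in range(colors.max() + 1):
        seed = jnp.asarray((colors == c).astype(x.dtype))
        probe = np.asarray(jax.jvp(f, (x,), (seed,))[1])  # sum of this group's columns
        for j in np.where(colors == c)[0]:
            rows = sparsity[:, j]
            J[rows, j] = probe[rows]  # de-interleave via the known pattern
    return J
```

For a banded Jacobian with bandwidth b this needs on the order of b `jvp` calls instead of n, independent of problem size.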
I'd love to see experimentation with things like automatic sparsity detection, though it's still unclear to me how well that would fit into JAX itself vs an add-on package.
CuSolver bindings do feel like something that could definitely belong inside JAX, so if you're interested in that, I would definitely encourage you to give it a shot!
Considering that we would need some efficient matrix coloring, this might be something that should go in a separate library. On the other hand, I'm not sure whether all the APIs for non-standard interpreters / transformations needed to infer the sparsity pattern are exposed.
Unfortunately, I don't have an immediate need nor the time to experiment with these things. It would be absolutely cool to have them, though. No more writing Jacobians by hand for classical optimization sounds like a great thing to me :-)
I have some use for sparse Jacobian computation, and created a small module which computes the Jacobian more efficiently when the sparsity is known beforehand. It might be relevant/interesting in this context. It relies on the networkx library to handle the coloring problem.
Also of interest: this library, which supports forward differentiation with sparse Jacobians. I have not tried it yet.
I extended sparsejac to support building Jacobians in forward mode, which may be the better choice depending upon the Jacobian sparsity.
I am skeptical that automatic sparsity detection can be robust in the general case, but it would be nice to handle this for the user if it can be done reliably.
@mfschubert, I would be interested in understanding better why you are skeptical that automatic sparsity detection can be robust. Do you have in mind some static analysis of the code to figure out the sparsity pattern? If I understand correctly, the approach you have followed consists in defining the matrix sparsity before the differentiation step, and then using matrix coloring to find a strategy that estimates the nonzero values of the matrix with a minimum number of matrix-vector products in forward mode or vector-matrix products in reverse mode.
In contrast, the method I implemented in my Matlab AutoDiff library (https://github.com/martinResearch/MatlabAutoDiff) does not require such a preliminary sparsity-detection step. The forward derivatives with respect to all the inputs are computed in a single forward pass using matrix-matrix products at each step of the forward computation (instead of the matrix-vector products used when differentiating with respect to one input at a time). Using sparse matrices, these matrix-matrix products remain sparse, and thus the sparsity of the Jacobian is obtained automatically as a result of sparse matrix products. In some way, this approach can be interpreted as computing both the Jacobian sparsity structure and the nonzero values simultaneously in a single forward pass. I believe the method implemented in https://github.com/PTNobel/AutoDiff is very similar to that (I would need to look at the source code to check that this is true).
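The idea can be sketched in Python with scipy.sparse (a toy operator-overloading style, not the actual Matlab implementation): each intermediate carries a (value, sparse-derivative) pair, and the Jacobian's sparsity simply falls out of the sparse matrix products.

```python
import numpy as np
import scipy.sparse as sp

# forward-mode pairs (value, d/dx as a sparse matrix); the Jacobian's
# sparsity emerges automatically from the sparse products below
def seed(x):
    return x, sp.identity(x.size, format="csr")

def mul(a, b):
    """Elementwise product rule: d(a*b) = diag(b) @ da + diag(a) @ db."""
    va, da = a
    vb, db = b
    return va * vb, sp.diags(vb) @ da + sp.diags(va) @ db

x = np.array([1.0, 2.0, 3.0])
v = seed(x)
val, jac = mul(v, v)  # y = x**2; its Jacobian diag(2x) stays sparse
```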
@martinResearch I was actually thinking of the sparsity estimation as proposed above by @cisprague, which if robust would be straightforward to incorporate. I haven't thought through alternatives, but what you've outlined seems like it could be promising.
> In contrast, the method I implemented in my Matlab AutoDiff library (https://github.com/martinResearch/MatlabAutoDiff) does not require such a preliminary sparsity-detection step. The forward derivatives with respect to all the inputs are computed in a single forward pass using matrix-matrix products at each step of the forward computation (instead of the matrix-vector products used when differentiating with respect to one input at a time). Using sparse matrices, these matrix-matrix products remain sparse, and thus the sparsity of the Jacobian is obtained automatically as a result of sparse matrix products. In some way, this approach can be interpreted as computing both the Jacobian sparsity structure and the nonzero values simultaneously in a single forward pass.
I'm not sure I follow how storing sparse Jacobians avoids the sparsity detection problem.
Even with sparse Jacobians, you have to pick the order in which to evaluate matrix-matrix products, which could mean either forward mode or reverse mode or something in between.
> you have to pick the order in which to evaluate matrix-matrix products, which could mean either forward mode or reverse mode or something in between.
That is true, and my Matlab implementation uses forward differentiation, which is easy to implement as it does not require storing the compute graph. As implemented, it does not use a preliminary pass through the whole function being differentiated before computing the Jacobian values, and it does not use graph coloring. It does not assume the sparsity pattern to be the same for every input: the Jacobian could thus potentially be dense for some inputs and sparse for others. And yet it exploits sparsity to be faster than dense Jacobian estimation when the Jacobian is sparse enough. This approach is described in section 5.3 of the TOMLAB MAD documentation: https://tomopt.com/docs/TOMLAB_MAD.pdf#page26
Maybe I misunderstood what is referred to as "sparsity detection" here. If I understand the method proposed by @cisprague correctly, it requires computing the full Jacobian at least once with a random input, and then assumes that the coefficients that are zero in that Jacobian will remain zero for other inputs. Intuitively, using random inputs seems to reduce the risk of false negatives (taking nonzeros as the positive class and zeros as the negative class) due to terms that "unluckily" cancel out, but I am not sure how to formalise that. It seems that if the sparsity pattern is actually input-dependent (if one uses piecewise polynomial functions, for example) this will fail. As an example, the function f(x, y) = max(0, x)^2 + y is differentiable everywhere and its Jacobian is [2x, 1] if x > 0 and [0, 1] if x <= 0. If x is negative in the random input used to guess the sparsity pattern, the guessed pattern will then be wrong, with too many zeros, whenever x > 0.
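This failure mode is easy to reproduce. The sketch below uses f(x, y) = max(0, x)² + y, which realizes the piecewise Jacobian described above: sampling the Jacobian at a negative x suggests a structural zero that disappears for positive x.

```python
import jax
import jax.numpy as jnp

f = lambda v: jnp.maximum(0.0, v[0]) ** 2 + v[1]
jac = jax.jacfwd(f)

# sampling at x < 0 suggests df/dx is structurally zero ...
pattern_guess = jac(jnp.array([-1.0, 0.0])) != 0.0  # only df/dy looks nonzero

# ... but at x > 0 the "zero" entry is 2x != 0, so the guessed pattern is wrong
actual = jac(jnp.array([2.0, 0.0]))
```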
OK, I think I was just confused about terminology.
@mfschubert + JAX devs,
How would folks feel about merging sparsejac as a PR into jax.experimental.sparse? It's an elegant and usable bit of code as-is, and having this in the JAX library would provide a common interface for those looking to hack in this area in the future (e.g., to try adding automated sparsity detection).
This introduces networkx as a new dependency; not sure how devs feel about that.
Thanks @peterdsharpe - I would be hesitant to add a networkx dependency to JAX, even if it's just optional. Additionally, I worry about making this part of jax.experimental because, due to its use of networkx, the algorithm is not compatible with JIT and other JAX transforms unless the sparsity is concrete. I wonder how hard it would be to write a JAX-compatible graph coloring utility?
The least_squares function in scipy can take a jac_sparsity argument in order to accelerate the estimation of the Jacobian with finite differences through fewer calls to the function. It uses a greedy method to group the columns so that any two columns within the same group have non-overlapping supports (equivalent to graph coloring?). The scipy code that does this is here: https://github.com/scipy/scipy/blob/8a64c938ddf1ae4c02a08d2c5e38daeb8d061d38/scipy/optimize/_group_columns.py#L59. Not sure how efficient that method is. Maybe it could be used as a reference for a JAX reimplementation?
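For reference, here is a minimal usage of that scipy feature (the diagonal toy problem is illustrative): declaring `jac_sparsity` lets `least_squares` estimate the whole Jacobian with one finite-difference step per column group rather than one per column.

```python
import numpy as np
from scipy.optimize import least_squares
from scipy.sparse import identity

n = 50

def residuals(x):
    # diagonal Jacobian: residual i depends only on x[i]
    return x ** 2 - np.arange(1.0, n + 1.0)

# with the pattern declared, all 50 columns fall into a single group,
# so each finite-difference Jacobian costs one extra function call
result = least_squares(residuals, x0=np.ones(n), jac_sparsity=identity(n))
```

Note that jac_sparsity is only supported by the default 'trf' method.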
@mfschubert thanks for the sparsejac work, I'm so glad that I found it! I tried it and it did bring significant improvement in certain cases.
Do you know what it would take to make the interface more similar to JAX's jacfwd and jacrev? One thing in particular I find missing is the ability to pass in more than one input, but only take derivatives with respect to one, e.g.:
```python
def g(x, vector_input, dict_input):
    pass

# Jacobian only w.r.t. the first input
dg_dx = jax.jacfwd(g)

# But dg_dx is still a function of the other inputs, and works with
# many types of inputs, thanks to pytrees
partial_jac = dg_dx(x, vector_input, dict_input)
```
I wonder (only a speculation): do all those high-level interfaces of JAX autodiff rely on some simple vector-in, vector-out function interface on the inside? If so, does it feel like sparsejac is something that ought to be baked into the low-level API rather than used directly by users?
@yunlongxu-numagic Glad that it is helpful! I actually just made some changes which allow some of the functionality you are looking for (I believe). You can now select the argument you wish to differentiate with respect to, although presently you may only differentiate with respect to a single argument. I also added support for has_aux.
Going forward, I think it would be useful to support differentiation with respect to multiple arguments and pytrees, rather than just single rank-1 arrays. This will be a bit more involved, however.
Some functions have a sparse Jacobian or sparse Hessian, and it can be useful to obtain them as sparse matrices rather than accessing the values through vector-Jacobian or Hessian-vector product functions: this would allow one to easily use sparse matrix factorization methods such as sparse Cholesky. This could be useful to easily formulate nonlinear least-squares problems and solve them efficiently using Levenberg-Marquardt, for example. A possible approach to obtain such a sparse Jacobian would be to use forward differentiation with the Jacobians represented as sparse matrices at each step of the computation. An example of this approach I implemented in Matlab can be found here. An implementation of this approach in Python that supports only vectors can be found here. Is this something that could be implemented in JAX once there is support for sparse matrices?