jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

feature request: sparse jacobian and sparse hessians #1032

Open martinResearch opened 5 years ago

martinResearch commented 5 years ago

Some functions have a sparse Jacobian or a sparse Hessian, and it can be useful to obtain them as sparse matrices rather than accessing their values through vector-Jacobian or Hessian-vector product functions. This would make it easy to use sparse matrix factorization methods such as sparse Cholesky, and for example to formulate nonlinear least-squares problems and solve them efficiently with Levenberg-Marquardt. A possible approach to obtaining such a sparse Jacobian would be to use forward differentiation with the Jacobians represented as sparse matrices at each step of the computation. An example of this approach that I implemented in Matlab can be found here. An implementation of this approach in Python that supports only vectors can be found here. Is this something that could be implemented in JAX once there is support for sparse matrices?
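For illustration only (this toy function is not part of the original request): even a simple elementwise map has a diagonal Jacobian, yet jax.jacfwd materializes it as a dense (n, n) array.

import jax
import jax.numpy as jnp

# elementwise function: its Jacobian is diag(2 * sin(x) * cos(x))
f = lambda x: jnp.sin(x) ** 2
x = jnp.arange(1.0, 5.0)
print(jax.jacfwd(f)(x))  # dense 4x4 array with only 4 structural nonzeros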

shoyer commented 5 years ago

I’m sure that these could be useful features, but indeed they will definitely be gated on support for sparse matrices — which unfortunately I would not expect anytime soon. Even then, I would not expect to see sparse Hessians/Jacobians if it requires rewriting every existing JVP rule.

That said, I'm not sure that these are strictly necessary, at least if you're willing to accept using an iterative method for the linear solve. I'm pretty confident that Levenberg-Marquardt, for example, could be done entirely in a matrix-free way by using Krylov methods like conjugate gradients for the solve. On a related note, I've been working on a Newton-Krylov solver leveraging JAX's autodiff that I hope to have up for review very soon.

martinResearch commented 5 years ago

Indeed, one could use Krylov methods. I just found out that the scipy sparse iterative linear solvers support providing the matrix as a LinearOperator, which would make that approach quite easy to use.
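For example, here is a minimal sketch (with a made-up residual function, names chosen only for illustration) of wrapping jax.jvp/jax.vjp as a scipy LinearOperator and solving the Levenberg-Marquardt normal equations matrix-free with CG:

import jax
import jax.numpy as jnp
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

# made-up residual function, just for illustration
def residual(x):
    return jnp.tanh(x) + 0.1 * jnp.roll(x, 1)

n = 100
x0 = jnp.ones(n)
r = np.asarray(residual(x0))

# J @ v via forward mode and J^T @ v via reverse mode, without ever forming J
def jac_matvec(v):
    return np.asarray(jax.jvp(residual, (x0,), (jnp.asarray(v),))[1])

_, vjp_fn = jax.vjp(residual, x0)
def jac_rmatvec(v):
    return np.asarray(vjp_fn(jnp.asarray(v))[0])

J = LinearOperator((n, n), matvec=jac_matvec, rmatvec=jac_rmatvec)

# one Levenberg-Marquardt step: solve (J^T J + lam * I) p = -J^T r with CG
lam = 1e-3
normal_op = LinearOperator((n, n), matvec=lambda v: J.rmatvec(J.matvec(v)) + lam * v)
p, info = cg(normal_op, -J.rmatvec(r))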

shoyer commented 5 years ago

Yes, we’ve written experimental versions of GMRES and CG in JAX.

ghost commented 4 years ago

Any chance something like this library has been or can be used with jax?

shoyer commented 4 years ago

@jamesthegiantpeach not at present, no. That's a little off-topic for this issue so I responded in #765.

cisprague commented 3 years ago

An approach I used to get a sparse representation is this:

from jax import jit, jacfwd
import numpy as np

# arbitrary function and its (dense) Jacobian
function = jit(lambda x: ...)
jacobian = jit(jacfwd(function))

# estimate the sparsity pattern from a random input
x = random(...)
sparse_ids = np.nonzero(jacobian(x))  # tuple of (row_indices, col_indices)

# evaluate the dense Jacobian but keep only the (presumed) nonzero entries
sparse_jacobian = jit(lambda x: jacobian(x)[sparse_ids])

Is there a better way to do this? Couldn't the sparsity pattern be inferred from the Jacobian without having to sample it?

tetterl commented 3 years ago

Is there any chance of sparse Jacobians with the upcoming sparse support (https://github.com/google/jax/issues/765, https://github.com/google/jax/pull/4422)? Or is there currently an efficient way to obtain sparse Jacobians (any scipy format is fine) in JAX?

For badly conditioned problems (e.g. with CGNR, the conjugate gradient method on the normal equations), iterative solvers are rather slow, and using direct solvers could improve performance substantially.

trologat commented 2 years ago

@cisprague I think an approach similar to the one used in Julia could infer the sparsity pattern (Structure and Sparsity Detection, https://github.com/JuliaDiff/SparseDiffTools.jl). This could be achieved by writing a non-standard interpreter (or transformation?) as in "Sparsity Programming: Automated Sparsity-Aware Optimizations in Differentiable Programming".

Even then, I would not expect to see sparse Hessians/Jacobians if it requires rewriting every existing JVP rule.

@shoyer Rewriting every existing JVP rule is actually not necessary, since an approach similar to https://github.com/JuliaDiff/SparseDiffTools.jl#colorvec-assisted-differentiation could be taken if we could get a matrix coloring from the sparsity pattern mentioned above. We could simply use jvp and hvp to compute a sparse Jacobian and Hessian, since we also have some support for sparse matrices now. Adding some sparse solvers like cuSolver (https://github.com/google/jax/issues/6544#issuecomment-930483553) could make JAX a compelling library for classical optimization problems.
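As a rough sketch (an assumption of what this could look like, not existing JAX functionality): given a known sparsity pattern (rows, cols) and a valid coloring of the columns, one JVP per color is enough to recover all nonzero Jacobian entries.

import jax
import jax.numpy as jnp
import numpy as np

def sparse_jacobian_values(f, x, rows, cols, colors):
    # colors[j] assigns column j to a group such that columns sharing a color
    # have non-overlapping row supports; one JVP per color then suffices
    values = np.zeros(len(rows))
    for c in range(int(colors.max()) + 1):
        # seed vector that sums all columns of color c
        seed = jnp.asarray((colors == c).astype(x.dtype))
        _, jv = jax.jvp(f, (x,), (seed,))   # jv = J @ seed
        jv = np.asarray(jv)
        # every nonzero (i, j) with colors[j] == c is the only contribution to row i of jv
        mask = colors[cols] == c
        values[mask] = jv[rows[mask]]
    return values  # nonzero values matching (rows, cols)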

shoyer commented 2 years ago

I'd love to see experimentation with things like automatic sparsity detection, though it's still unclear to me how well that would fit into JAX itself vs. an add-on package.

cuSolver bindings do feel like something that could belong inside JAX, so if you're interested in that, I would definitely encourage you to give it a shot!

trologat commented 2 years ago

Considering that we would need some efficient matrix coloring, this might be something that should go in a separate library. On the other hand, I'm not sure whether all the APIs for non-standard interpreters / transformations needed to infer the sparsity pattern are exposed.

Unfortunately, I have neither an immediate need nor the time to experiment with these things. It would be absolutely cool to have them, though. No more writing of Jacobians by hand for classical optimization sounds like a great thing to me :-)

mfschubert commented 2 years ago

I have some use for sparse Jacobian computation and created a small module which computes the Jacobian more efficiently when the sparsity is known beforehand. It might be relevant/interesting in this context. It relies on the networkx library to handle the coloring problem.

https://github.com/mfschubert/sparsejac

martinResearch commented 2 years ago

Also of interest: this library, which supports forward differentiation with sparse Jacobians. I have not tried it yet.

https://github.com/PTNobel/AutoDiff

mfschubert commented 2 years ago

I extended sparsejac to support building Jacobians in forward mode as well, which may be the better choice depending on the Jacobian sparsity.

I am skeptical that automatic sparsity detection can be robust in the general case, but it would be nice to handle this for the user if it can be done reliably.

martinResearch commented 2 years ago

@mfschubert, I would be interested in understanding better why you are skeptical that automatic sparsity detection can be robust. Do you have in mind some static analysis of the code to figure out the sparsity pattern? If I understand correctly, the approach you have followed consists of defining the matrix sparsity before the differentiation step, and then using matrix coloring to find a strategy that estimates the nonzero values of the matrix with a minimum number of matrix-vector products in forward mode, or vector-matrix products in reverse mode.

In contrast, the method I implemented in my Matlab AutoDiff library (https://github.com/martinResearch/MatlabAutoDiff) does not require such a preliminary sparsity detection step. The forward derivatives with respect to all the inputs are computed in a single forward pass, using matrix-matrix products at each step of the forward computation (instead of the matrix-vector products used when differentiating with respect to one input at a time). With sparse matrices, these matrix-matrix products remain sparse, and thus the sparsity of the Jacobian is obtained automatically as a result of the sparse matrix products. In some sense this approach can be interpreted as computing both the Jacobian sparsity structure and the nonzero values simultaneously in a single forward pass. I believe the method implemented in https://github.com/PTNobel/AutoDiff is very similar (I would need to look at the source code to check that this is true).
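A minimal Python sketch of the idea (not the MatlabAutoDiff code itself, just an illustration using scipy.sparse to carry the Jacobian of each intermediate value):

import numpy as np
import scipy.sparse as sp

class SparseDual:
    def __init__(self, value, jac):
        self.value = value   # ndarray of shape (n,)
        self.jac = jac       # scipy.sparse matrix of shape (n, n_inputs)

    @staticmethod
    def input(x):
        # seed the Jacobian with a sparse identity: d(x)/d(x) = I
        x = np.asarray(x, dtype=float)
        return SparseDual(x, sp.identity(len(x), format="csr"))

    def __add__(self, other):
        return SparseDual(self.value + other.value, self.jac + other.jac)

    def __mul__(self, other):
        # d(u * v) = diag(v) @ du + diag(u) @ dv, all sparse products
        return SparseDual(
            self.value * other.value,
            sp.diags(other.value) @ self.jac + sp.diags(self.value) @ other.jac,
        )

    def sin(self):
        return SparseDual(np.sin(self.value), sp.diags(np.cos(self.value)) @ self.jac)

# example: for an elementwise expression the diagonal Jacobian emerges
# automatically from the sparse products above, with no sparsity detection step
x = SparseDual.input(np.arange(1.0, 6.0))
y = x * x + x.sin()
print(y.jac.toarray())   # diagonal matrix with entries 2*x + cos(x)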

mfschubert commented 2 years ago

@martinResearch I was actually thinking of the sparsity estimation as proposed above by @cisprague, which if robust would be straightforward to incorporate. I haven't thought through alternatives, but what you've outlined seems like it could be promising.

shoyer commented 2 years ago

In contrast, the method I implemented in my Matlab AutoDiff library (https://github.com/martinResearch/MatlabAutoDiff) does not require such a preliminary sparsity detection step. The forward derivatives with respect to all the inputs are computed in a single forward pass, using matrix-matrix products at each step of the forward computation (instead of the matrix-vector products used when differentiating with respect to one input at a time). With sparse matrices, these matrix-matrix products remain sparse, and thus the sparsity of the Jacobian is obtained automatically as a result of the sparse matrix products. In some sense this approach can be interpreted as computing both the Jacobian sparsity structure and the nonzero values simultaneously in a single forward pass.

I'm not sure I follow how storing sparse Jacobians avoids the sparsity detection problem.

Even with sparse Jacobians, you have to pick the order in which to evaluate matrix-matrix products, which could mean either forward mode or reverse mode or something in between.

martinResearch commented 2 years ago

you have to pick the order in which to evaluate matrix-matrix products, which could mean either forward mode or reverse mode or something in between.

That is true, and my Matlab implementation uses forward differentiation, which is easy to implement as it does not require storing the compute graph. As implemented, it does not make a preliminary pass through the whole function being differentiated before computing the Jacobian values, and it does not use graph coloring. It does not assume the sparsity pattern to be the same for every input: the Jacobian could thus potentially be dense for some inputs and sparse for others. And yet it exploits sparsity to be faster than dense Jacobian estimation when the Jacobian is sparse enough. This approach is described in section 5.3 of the TOMLAB MAD documentation: https://tomopt.com/docs/TOMLAB_MAD.pdf#page26

Maybe I misunderstood what is referred to as "sparsity detection" here. If I understand the method proposed by @cisprague correctly, it requires computing the full dense Jacobian at least once with a random input, and then assumes that the coefficients that are zero in that Jacobian will remain zero for other inputs. Intuitively, using random inputs seems to reduce the risk of false negatives (treating nonzeros as the positive class and zeros as the negative class) caused by terms that "unluckily" cancel out, but I am not sure how to formalize that. It seems that if the sparsity pattern is actually input dependent (for example with piecewise polynomial functions) this will fail. As an example, the function f(x, y) = max(0, x)^2 + y is differentiable everywhere, and its Jacobian is [2x, 1] if x > 0 and [0, 1] if x <= 0. If x is negative in the random input used to guess the sparsity pattern, the inferred pattern will have too many zeros when evaluated at x > 0.
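A quick check of that example with JAX (function name chosen just for illustration):

import jax
import jax.numpy as jnp

def f(v):
    x, y = v
    return jnp.maximum(0.0, x) ** 2 + y

jac = jax.jacfwd(f)
print(jac(jnp.array([-1.0, 2.0])))  # [0., 1.]  -> a pattern guessed here misses df/dx
print(jac(jnp.array([1.0, 2.0])))   # [2., 1.]  -> df/dx is nonzero for x > 0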

shoyer commented 2 years ago

OK, I think I was just confused about terminology.

peterdsharpe commented 2 years ago

@mfschubert + JAX devs,

How would folks feel about merging sparsejac as a PR into jax.experimental.sparse? It's an elegant and usable bit of code as-is, and having this in the JAX library would provide a common interface for those looking to hack in this area in the future (e.g., to try adding automated sparsity detection).

This introduces networkx as a new dependency; not sure how devs feel about that.

jakevdp commented 2 years ago

Thanks @peterdsharpe - I would be hesitant to add a networkx dependency to JAX, even if it's just optional. Additionally, I worry about making this part of jax.experimental because, due to its use of networkx, the algorithm is not compatible with JIT and other JAX transforms unless the sparsity is concrete. I wonder how hard it would be to write a JAX-compatible graph coloring utility?

martinResearch commented 2 years ago

The least_squares function in scipy can take a jac_sparsity argument in order to accelerate the estimation of the Jacobian by finite differences through fewer calls to the function. It uses a greedy method to group the columns so that two columns within the same group have non-overlapping supports (equivalent to graph coloring?). The code that does this in scipy is here: https://github.com/scipy/scipy/blob/8a64c938ddf1ae4c02a08d2c5e38daeb8d061d38/scipy/optimize/_group_columns.py#L59. I am not sure how efficient that method is. Maybe it could be used as a reference for a JAX reimplementation?
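For reference, a simplified pure-NumPy version of that greedy grouping (just a sketch, not the scipy code) could look like this:

import numpy as np

def group_columns(pattern):
    # pattern: boolean (m, n) array marking the nonzeros of the Jacobian;
    # columns placed in the same group must have non-overlapping row supports
    m, n = pattern.shape
    groups = -np.ones(n, dtype=int)
    group_support = []   # rows already covered by each group
    for j in range(n):
        col = pattern[:, j]
        for g, support in enumerate(group_support):
            if not np.any(support & col):        # no overlapping rows
                groups[j] = g
                group_support[g] = support | col
                break
        else:
            groups[j] = len(group_support)       # open a new group
            group_support.append(col.copy())
    return groups

# a tridiagonal pattern needs only 3 groups regardless of n
n = 8
pattern = (np.eye(n, dtype=bool)
           | np.eye(n, k=1, dtype=bool)
           | np.eye(n, k=-1, dtype=bool))
print(group_columns(pattern))   # [0 1 2 0 1 2 0 1]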

yunlongxu-numagic commented 2 years ago

@mfschubert thanks for the sparsejac work, I'm so glad that I found it! I tried it and it did bring some significant improvement in certain cases.

Do you know what it would take to make the interface more similar to JAX's jacfwd and jacrev? One thing in particular I find missing is the ability to pass in more than one input but only take derivatives with respect to one, e.g.:

import jax

def g(x, vector_input, dict_input):
    pass

# Jacobian only w.r.t. the first argument (argnums=0 is the default)
dg_dx = jax.jacfwd(g)

# But dg_dx is still a function of the other inputs, and works with many
# types of inputs thanks to pytrees
partial_jac = dg_dx(x, vector_input, dict_input)

I wonder (only a speculation): do all those high-level interfaces of JAX's autodiff rely on some simple vector-in, vector-out function interface on the inside? If so, does it feel like sparsejac is something that ought to be baked into the low-level API rather than used directly by the user?

mfschubert commented 2 years ago

@yunlongxu-numagic Glad that it is helpful! I actually just made some changes which allow some of the functionality you are looking for (I believe). You can now select the argument you wish to differentiate with respect to, although presently you may only differentiate with respect to a single argument. I also added support for has_aux.

Going forward, I think it would be useful to support differentiation with respect to multiple arguments and pytrees, rather than just single rank-1 arrays. This will be a bit more involved, however.