JuliaDiff / SparseDiffTools.jl

Fast jacobian computation through sparsity exploitation and matrix coloring
MIT License

Automated sparsity detection #10

Closed ChrisRackauckas closed 5 years ago

ChrisRackauckas commented 5 years ago

@shashi got the first round going. From talks with @vchuravy and all, it seems like there are a few ways to go about this:

So I think the sparsity API should dispatch between the different methods, since they all have trade-offs: `sparsity!(f!, Y, X, method=TraceGraph(), S=Sparsity(length(Y), length(X)))`. A last method could be a `DefaultDetection()` that works like:
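To make the dispatch idea concrete, here is a rough Python analogue of the proposed API (the thread's actual package is Julia, where this would be real multiple dispatch). The names `TraceGraph` and `Sparsity` come from the proposal above; everything inside them, including the brute-force perturbation "detector", is an illustrative stand-in, not the real tracing mechanism.

```python
class Sparsity:
    """Records (row, col) pairs where the Jacobian of f may be nonzero."""
    def __init__(self, nrows, ncols):
        self.nrows, self.ncols = nrows, ncols
        self.entries = set()

    def add(self, i, j):
        self.entries.add((i, j))

class TraceGraph:
    """Stand-in for the graph-tracing detection method."""
    def detect(self, f, y, x, S):
        # Crude placeholder: perturb one input at a time and record
        # which outputs change. A real tracer would instrument f instead.
        base = list(y)
        f(base, x)
        for j in range(len(x)):
            xp = list(x)
            xp[j] += 1.0
            out = list(y)
            f(out, xp)
            for i in range(len(y)):
                if out[i] != base[i]:
                    S.add(i, j)
        return S

def sparsity(f, y, x, method=None, S=None):
    # "Dispatch" on the method object; defaults mimic the proposed API.
    method = method or TraceGraph()
    S = S or Sparsity(len(y), len(x))
    return method.detect(f, y, x, S)

# Toy in-place function: y0 depends on x0 and x1, y1 only on x0.
def f(y, x):
    y[0] = x[0] + x[1]
    y[1] = 2.0 * x[0]

S = sparsity(f, [0.0, 0.0], [1.0, 2.0])
print(sorted(S.entries))   # [(0, 0), (0, 1), (1, 0)]
```

The point of the strategy-object argument is that adding a new detection method only requires a new type with a `detect` method, not a change to the `sparsity` entry point.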

As for tests, the test equations would use `(du,u)->f(du,u,p,t)` from definitions of differential equations. It would also be nice if there were a dispatch on `DEProblem` that remakes the problem with sparse-matrix support, but let's leave that for later. The Lorenz equations are good for unit tests: http://docs.juliadiffeq.org/latest/tutorials/ode_example.html#Example-2:-Solving-Systems-of-Equations-1 . For a bigger example, https://github.com/JuliaDiffEq/DiffEqProblemLibrary.jl has a few. The Brusselator problem is particularly interesting

https://github.com/JuliaDiffEq/DiffEqProblemLibrary.jl/blob/master/src/ode/brusselator_prob.jl

since it matches a lot of things we'd typically see in GPU-based PDE code. The different forms of the PDE from http://juliadiffeq.org/DiffEqTutorials.jl/html/introduction/optimizing_diffeq_code.html are also interesting, and we should make sure we support both the optimized and unoptimized forms well.

vchuravy commented 5 years ago

> Branch elimination. Essentially it's what Shashi's PR was, except at every branch, you just eliminate the branch and inline both sides of it.

The "inline both sides of it" part is non-trivial if you want to execute the result, as Shashi's Cassette-based tool needs to; take the simple example where one branch throws an error.

Shashi and I discussed looking further into abstract interpretation, maybe based on https://github.com/JuliaDebug/TypedCodeUtils.jl, which would let you do this using reflection alone, but it has the strong requirement of statically typed code.

shashi commented 5 years ago

Yeah, we can't just run both sides of a branch; at that point we will need to resort to type-inferred expansion, or concolic fuzzing, or a combination of both. I'll work up some contrived examples.

Another idea I'm considering, which may not need a SAT solver if it turns out to be workable for Chris, is to keep track of the branches taken in the code, e.g. as a bit vector or a tuple of bools (even as a type parameter, if that's usually beneficial for compilation and restart). This can actually uniquely determine a code path, and it can be tracked using a Cassette `@pass`. (The path may become too big if there's an if statement affecting the sparsity inside a while loop, but hopefully that's not going to be a big problem; those cases can be detected and dealt with differently.)

So we can make the graph coloring found in the first run contingent on hitting the same code path (branch bit vector) when it is applied. If run 2 starts to deviate from the currently known paths, we can recompute the Sparsity from that point on (possibly using FunctionalCollections to store all the sparsity vectors), combine it with all known sparsities from the past, and do some kind of incremental graph-coloring update (there seem to be a bunch of algorithms for this).
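A minimal Python sketch of the idea above, to pin down the mechanics. In the real proposal the branch trace would be collected automatically by a Cassette pass; here the toy function is instrumented by hand through a `branch` helper, and a cache maps each branch bit-vector to the sparsity pattern observed on that path. All names are made up for this sketch.

```python
def make_tracer():
    """Return a branch-recording helper and the path it fills in."""
    path = []
    def branch(cond):
        path.append(bool(cond))   # record which way each branch went
        return cond
    return branch, path

def detect_pattern(f, x):
    """Run f once, returning (branch bit-vector, sparsity pattern)."""
    branch, path = make_tracer()
    deps = f(x, branch)           # stand-in: f reports its own (out, in) deps
    return tuple(path), frozenset(deps)

# Cache: branch bit-vector -> known sparsity for that code path. Hitting a
# new path extends the cache, mirroring the "incremental update" idea.
cache = {}
def sparsity_for(f, x):
    path, deps = detect_pattern(f, x)
    if path not in cache:
        cache[path] = deps        # new code path: record its sparsity
    return path, cache[path]

# Toy function whose sparsity genuinely depends on a branch.
def f(x, branch):
    if branch(x[0] > 0):
        return {(0, 0)}           # y0 depends only on x0
    else:
        return {(0, 0), (0, 1)}   # y0 depends on x0 and x1

p1, s1 = sparsity_for(f, [1.0, 2.0])    # path (True,)
p2, s2 = sparsity_for(f, [-1.0, 2.0])   # path (False,) -> new cache entry
```

The key property is the one claimed in the comment: the bit vector of branch outcomes uniquely keys the code path, so a cached coloring can be reused whenever the same path is hit and recomputed only on deviation.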

vchuravy commented 5 years ago

> Another idea I'm thinking of that may not need SAT solver if it turns out to be workable for Chris is to keep track of branches taken in the code, e.g. just a bit vector or a tuple of bools (even as a type parameter if that's usually beneficial for compilation and restart...), this can actually uniquely determine a code path.

Oh, this is interesting: you could literally use Cassette to rewrite the branches to just be `true` and `false`, thereby removing the need for a SAT solver.

shashi commented 5 years ago

> you could literally use Cassette to rewrite the branches to just be true, and false

Oh, I'm not suggesting that kind of rewrite; also, that would be wrong, right? For example:

```julia
x < 0 ? sqrt(abs(x)) : sqrt(x)
```

What I'm suggesting is incrementally updating the graph coloring as new inputs explore new paths.
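To spell out why forcing both sides of that branch is wrong, here is a Python analogue of the snippet: the branch guards the domain of `sqrt`, so the un-taken side can be invalid for the current input (the function name `f` is just for this illustration).

```python
import math

def f(x):
    # Same logic as the Julia ternary: x < 0 ? sqrt(abs(x)) : sqrt(x)
    return math.sqrt(abs(x)) if x < 0 else math.sqrt(x)

# The branch keeps every input in sqrt's domain:
print(f(-4.0))   # 2.0
print(f(4.0))    # 2.0

# But blindly executing the "false" side for x = -4.0, as a pass that
# rewrites the condition to false would do, raises a domain error.
try:
    math.sqrt(-4.0)
except ValueError as e:
    print("forced branch fails:", e)
```

This is the same failure mode vchuravy notes in the next comment: the rewritten branch runs under a condition it was never meant to satisfy.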

vchuravy commented 5 years ago

Ah, yes, you would hit branches whose conditions don't match. The question is: how do you generate new inputs?

shashi commented 5 years ago

I'm guessing that we don't need to: the differentiation algorithm will call the function with whatever inputs it requires, and those may trigger different paths; if they don't, even better. :) So the worldview to take is not "we have figured out a conservative estimate of the sparsity pattern" but "we have a coloring that works for this code path, and we're willing to update it if a new code path gets exercised." If that's possible.

vchuravy commented 5 years ago

Yeah, it would be nice if we don't need to precompute all paths, but that might cause us to be overly optimistic.