JuliaDecisionFocusedLearning / ImplicitDifferentiation.jl

Automatic differentiation of implicit functions
https://juliadecisionfocusedlearning.github.io/ImplicitDifferentiation.jl/
MIT License

Direct linear solver not efficiently handled #59

Closed mohamed82008 closed 1 year ago

mohamed82008 commented 1 year ago

https://github.com/gdalle/ImplicitDifferentiation.jl/blob/7f9fa8fd5b3851fd55421501b1f061f784c68ab0/ext/ImplicitDifferentiationChainRulesExt.jl#L63

Currently, to use a direct linear solver, I need to pass in (A, b) -> (Matrix(A) \ b, (solved=true,)). There are two problems with that. First, it is inconvenient to pass it in like this. Second, it is inefficient because it calls Matrix(A) and factorises the result every time the pullback function is called. Materialising and factorising A can be expensive and only needs to be done once, after which the factorisation can be reused for every pullback call. This currently makes Jacobian computations with direct solvers much slower than necessary.
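A minimal sketch of the difference, using only the LinearAlgebra standard library. The matrix A, its size, and the right-hand sides are illustrative assumptions standing in for the operator and cotangents of a real pullback; this is not the package's code.

```julia
using LinearAlgebra

# Illustrative stand-in for the operator A: a dense, well-conditioned matrix.
n = 50
A = rand(n, n) + n * I
bs = [rand(n) for _ in 1:n]  # one right-hand side per pullback call

# The workaround from the comment above: materialise and factorise on every call.
naive_solver = (A, b) -> (Matrix(A) \ b, (solved=true,))
xs_naive = [first(naive_solver(A, b)) for b in bs]

# Alternative: factorise once, then reuse the factorisation for every right-hand side.
F = lu(Matrix(A))
xs_cached = [F \ b for b in bs]
```

For a dense n-by-n matrix, each LU factorisation costs O(n^3) while a solve with an existing factorisation costs O(n^2), so a full Jacobian (n pullbacks) pays the cubic cost n times in the first variant and once in the second.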

gdalle commented 1 year ago

The same goes for the forward rule with dual numbers: we have n pushforwards that solve with the same matrix but different target vectors.
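As a rough illustration of that situation, the sketch below solves against several target vectors with one shared factorisation. The sizes and the random data are assumptions chosen for the example, not taken from the package.

```julia
using LinearAlgebra

n, k = 20, 5
A = rand(n, n) + n * I   # shared matrix (illustrative)
B = rand(n, k)           # k target vectors stored as columns (illustrative)

F = lu(A)                # factorise once

# One solve per target vector, all reusing the same factorisation ...
X_cols = reduce(hcat, [F \ B[:, j] for j in 1:k])

# ... or a single call with a matrix right-hand side.
X = F \ B
```

Both variants give the same solutions; the second simply hands all target vectors to the factorisation at once.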

Should we add some form of factorization caching?

mohamed82008 commented 1 year ago

The easiest way would be to have a separate presolve function that defaults to identity but becomes A -> lu(Matrix(A)) when the solver is a direct solver. This means we need to offer a function direct that users can pass as linear_solver = direct, and then we set presolve = linear_solver === direct ? (A -> lu(Matrix(A))) : identity.
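A hedged sketch of how this proposal could look. The names direct, choose_presolve, and iterative_placeholder are hypothetical and only illustrate the idea; they are not the package's actual API, and the eventual PR may differ.

```julia
using LinearAlgebra

# Hypothetical marker function: passing `linear_solver = direct` requests a direct solve.
# It follows the (solution, stats) return convention used in the issue.
direct(A, b) = (A \ b, (solved=true,))

# Hypothetical helper that picks the presolve step from the solver, as proposed above.
choose_presolve(linear_solver) =
    linear_solver === direct ? (A -> lu(Matrix(A))) : identity

# Hypothetical placeholder for a non-direct solver (not a real Krylov method).
iterative_placeholder(A, b) = (Matrix(A) \ b, (solved=true,))

A = rand(10, 10) + 10 * I
b = rand(10)

presolve = choose_presolve(direct)
A_pre = presolve(A)           # LU factorisation, computed once
x, stats = direct(A_pre, b)   # `\` on an LU object reuses the factorisation

# Any other solver leaves the operator untouched.
presolve_other = choose_presolve(iterative_placeholder)
```

The presolved object would be computed once when the rule is set up and then captured by the pullback or pushforward closure, so repeated calls only pay for the triangular solves.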

mohamed82008 commented 1 year ago

In theory, presolve can also specialise the pushforward or pullback operators for the functions at hand, but these are more advanced use cases.

gdalle commented 1 year ago

Maybe there is a presolve step in the Krylov solvers too that we could exploit

mohamed82008 commented 1 year ago

I have a draft PR locally. I will clean it up in the next couple of weeks.

gdalle commented 1 year ago

In the meantime maybe I could set up package benchmarks, so that we see the improvement?

mohamed82008 commented 1 year ago

yes please

mohamed82008 commented 1 year ago

> Maybe there is a presolve step in the Krylov solvers too that we could exploit

In theory, we might be able to figure out a good preconditioner upfront.