mohamed82008 opened 2 years ago
Where would this be used?
As an optimisation, e.g. to generalise ForwardOptimize in ReverseDiff. If I have a function in a chain that I know has a sparse Jacobian with a known sparsity pattern, I can use Zygote/Diffractor to differentiate the whole chain except for this one function, which can use compressed ForwardDiff to apply the adjoint rule. This is common in PDE-related code, where you have a lot of broadcasting-like operations done on each element, followed by global operations (e.g. a linear/nonlinear system solve), followed by more element-wise or node-wise operations for post-processing, followed by more global operations (e.g. sum or average). In this case, it makes sense to wrap the element-wise/node-wise stuff in SparseJacFunction and pass in the sparsity pattern.
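To make "compressed ForwardDiff to apply the adjoint rule" concrete, here is a minimal, language-neutral Python sketch (the thread has no code of its own; this is not ForwardDiff's or SparseDiffTools' API). Assuming a toy `Dual` number type and a user-supplied sparsity pattern given as a set of `(row, col)` pairs, it greedily colors the columns so that same-colored columns never share a row, runs one forward pass per color, and then applies the adjoint rule `w = J^T v` from the stored nonzeros. `Dual`, `greedy_column_coloring`, `sparse_jacobian`, `vjp` and `chain_link` are hypothetical names.

```python
class Dual:
    """A value plus one forward-mode tangent (derivative along a seed direction)."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __add__(self, other):
        o = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + o.val, self.der + o.der)

    __radd__ = __add__

    def __mul__(self, other):
        o = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)

    __rmul__ = __mul__


def greedy_column_coloring(pattern, n_cols):
    """Columns that share a row get different colors (greedy, in column order)."""
    rows_of = [set() for _ in range(n_cols)]
    for i, j in pattern:
        rows_of[j].add(i)
    colors = []
    for j in range(n_cols):
        used = {colors[k] for k in range(j) if rows_of[j] & rows_of[k]}
        c = 0
        while c in used:
            c += 1
        colors.append(c)
    return colors


def sparse_jacobian(f, x, pattern, colors):
    """Compressed forward mode: one pass per color, seeding all columns of that color."""
    J = {}
    for color in range(max(colors) + 1):
        seeds = [Dual(xi, 1.0 if colors[j] == color else 0.0) for j, xi in enumerate(x)]
        out = f(seeds)
        # Same-colored columns never share a row, so each output tangent
        # belongs to exactly one structural nonzero of this color.
        for i, j in pattern:
            if colors[j] == color:
                J[(i, j)] = out[i].der
    return J


def vjp(J, v, n_cols):
    """Adjoint rule: w = J^T v, touching only the stored nonzeros."""
    w = [0.0] * n_cols
    for (i, j), jij in J.items():
        w[j] += jij * v[i]
    return w


# Usage: y_i = x_i * x_{i+1}, a banded pattern that needs only 2 colors.
def chain_link(x):
    return [x[i] * x[i + 1] for i in range(len(x) - 1)]


pattern = {(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (2, 3)}
colors = greedy_column_coloring(pattern, 4)  # [0, 1, 0, 1]
J = sparse_jacobian(chain_link, [1.0, 2.0, 3.0, 4.0], pattern, colors)
print(vjp(J, [1.0, 1.0, 1.0], 4))  # [2.0, 4.0, 6.0, 3.0]
```

The cost is two forward passes here regardless of the input length, because the number of passes equals the number of colors, not the number of inputs. The sketch trusts the supplied pattern: a true nonzero missing from it would silently corrupt the result.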
Well, an element-wise function can always just use a single forward-mode partial, since J'v = Jv when J is diagonal. It would really only come into play if you know of a non-diagonal sparsity pattern where the column coloring is significantly small and reverse mode has a significant overhead on that type of function. I mean, we might as well implement it, but that seems rarer: something that could happen with specific sparsity patterns of a nonlinear system of PDEs that uses a lot of scalar indexing and no linear algebra.
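A minimal Python sketch of the diagonal case, assuming a toy `Dual` type rather than ForwardDiff's: seeding the tangent with `v` gives `J v` in one forward pass, and because an element-wise map has a diagonal (hence symmetric) Jacobian, that is already `J' v`. The names `elementwise`, `jvp` and `vjp_diagonal` are hypothetical.

```python
class Dual:
    """Value plus one forward-mode tangent."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __mul__(self, other):
        o = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)

    __rmul__ = __mul__


def elementwise(x):
    # y_i depends only on x_i, so J = diag(3 * x_i**2)
    return [xi * xi * xi for xi in x]


def jvp(f, x, v):
    """A single forward pass with tangent seed v returns J v."""
    return [y.der for y in f([Dual(xi, vi) for xi, vi in zip(x, v)])]


def vjp_diagonal(f, x, v):
    # J is diagonal, hence symmetric: J^T v == J v, so the adjoint rule
    # reuses the one forward pass instead of running reverse mode.
    return jvp(f, x, v)


x, v = [1.0, 2.0, -1.0], [1.0, 0.5, 2.0]
print(vjp_diagonal(elementwise, x, v))  # [3.0, 6.0, 6.0]
```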
> It would really only come into play if you know of a non-diagonal sparsity pattern where the column coloring is significantly small and reverse mode has a significant overhead on that type of function.
That's usually the case whenever you have node-element interactions in the broadcasting-like operations. It's "finite element"-wise, not element-wise in the sense of one number at a time.
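A rough illustration of why node-element coupling keeps the coloring small, assuming a 1D mesh of linear elements as a hypothetical stand-in for the "finite element"-wise kernels: the residual at node i couples nodes i-1, i and i+1, giving a tridiagonal pattern. A plain greedy column coloring (a standard heuristic, not necessarily what SparseDiffTools uses) then needs 3 colors, i.e. 3 forward passes, no matter how many nodes there are. `tridiagonal_pattern`, `greedy_column_coloring` and `is_valid_coloring` are hypothetical helpers.

```python
from collections import defaultdict


def tridiagonal_pattern(n):
    """1D mesh with linear elements: the residual at node i couples nodes i-1, i, i+1."""
    return {(i, j) for i in range(n) for j in (i - 1, i, i + 1) if 0 <= j < n}


def greedy_column_coloring(pattern, n_cols):
    """Give each column the smallest color not used by a column sharing a row with it."""
    cols_in_row = defaultdict(set)
    rows_of_col = defaultdict(set)
    for i, j in pattern:
        cols_in_row[i].add(j)
        rows_of_col[j].add(i)
    colors = [-1] * n_cols  # -1 = not colored yet
    for j in range(n_cols):
        used = {colors[k] for i in rows_of_col[j] for k in cols_in_row[i] if colors[k] >= 0}
        c = 0
        while c in used:
            c += 1
        colors[j] = c
    return colors


def is_valid_coloring(pattern, colors):
    """No row may contain two structural nonzeros whose columns share a color."""
    seen = set()
    for i, j in pattern:
        if (i, colors[j]) in seen:
            return False
        seen.add((i, colors[j]))
    return True


for n in (10, 1000):
    pat = tridiagonal_pattern(n)
    colors = greedy_column_coloring(pat, n)
    # Forward passes needed: number of colors, versus n for uncompressed forward mode.
    print(n, max(colors) + 1, is_valid_coloring(pat, colors))
# 10 3 True
# 1000 3 True
```

Whether 3 compressed forward passes beat one reverse pass depends on the reverse-mode overhead of the kernel, which is the benchmarking question raised in the thread.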
For that case, wouldn't you just use Enzyme in the kernels?
Perhaps I should. But ForwardDiff-based implementations should also be competitive.
Yeah, might as well benchmark it.
It would be great to define a struct that wraps a vector-valued function of a vector input, such that the colouring is done once, from the user-supplied sparsity pattern, when an instance is constructed. Then an frule and an rrule can be defined for this function using ChainRulesCore, and the function struct will automatically be usable in something like Nonconvex or GalacticOptim.
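A Python mock of the proposed wrapper's shape, under stated assumptions: this is not ChainRulesCore, and the `frule`/`rrule` methods only imitate the shape of its rules (a pushforward returning the output plus a tangent, and a pullback closure). The name `SparseJacFunction` is taken from the thread; `Dual` and everything else are hypothetical. The colouring is computed once in the constructor; `rrule` reuses it to build the Jacobian with one forward pass per color, then returns a pullback computing `J^T dy`. Integration with Nonconvex or GalacticOptim is not shown.

```python
class Dual:
    """Value plus one forward-mode tangent."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __add__(self, other):
        o = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + o.val, self.der + o.der)

    __radd__ = __add__

    def __mul__(self, other):
        o = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)

    __rmul__ = __mul__


class SparseJacFunction:
    """Wraps f: R^n -> R^m with a user-supplied sparsity pattern {(row, col), ...}.
    The column coloring is computed once, in the constructor, and reused by rrule.
    f must be written with operations Dual supports (here + and *)."""

    def __init__(self, f, pattern, n_in):
        self.f, self.pattern, self.n_in = f, set(pattern), n_in
        rows_of = [set() for _ in range(n_in)]
        for i, j in self.pattern:
            rows_of[j].add(i)
        self.colors = []
        for j in range(n_in):  # greedy coloring: columns sharing a row get different colors
            used = {self.colors[k] for k in range(j) if rows_of[j] & rows_of[k]}
            c = 0
            while c in used:
                c += 1
            self.colors.append(c)

    def __call__(self, x):
        return self.f(list(x))

    def frule(self, x, dx):
        """Pushforward: (y, J dx) from a single forward pass seeded with dx."""
        out = self.f([Dual(xi, di) for xi, di in zip(x, dx)])
        return [o.val for o in out], [o.der for o in out]

    def rrule(self, x):
        """Returns (y, pullback) with pullback(dy) = J^T dy.
        J is built by compressed forward mode: one pass per color."""
        J = {}
        for color in set(self.colors):
            seeds = [Dual(xi, 1.0 if self.colors[j] == color else 0.0) for j, xi in enumerate(x)]
            out = self.f(seeds)
            for i, j in self.pattern:
                if self.colors[j] == color:
                    J[(i, j)] = out[i].der

        def pullback(dy):
            dx = [0.0] * self.n_in
            for (i, j), jij in J.items():
                dx[j] += jij * dy[i]
            return dx

        return self(x), pullback


# Usage: y_i = x_i * x_{i+1}; the coloring [0, 1, 0, 1] is computed once here.
sjf = SparseJacFunction(lambda x: [x[i] * x[i + 1] for i in range(len(x) - 1)],
                        {(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (2, 3)}, 4)
y, ydot = sjf.frule([1.0, 2.0, 3.0, 4.0], [1.0, 0.0, 0.0, 0.0])
y, pb = sjf.rrule([1.0, 2.0, 3.0, 4.0])
print(y, pb([1.0, 1.0, 1.0]))  # [2.0, 6.0, 12.0] [2.0, 4.0, 6.0, 3.0]
```

One design note: `frule` does not need the colouring at all, since a single directional derivative is one forward pass. The one-time colouring pays off in `rrule`, where it bounds the number of forward passes needed to assemble `J`.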