RalphAS / GenericSchur.jl

Julia package for Schur decomposition of matrices with generic element types

Use of ForwardDiff through `schur` #10

Closed: baggepinnen closed this issue 1 year ago

baggepinnen commented 1 year ago

Hello there 👋

Would it make sense to allow `ForwardDiff.Dual` numbers through the `schur` factorization in this package? Currently, many methods are restricted to

`StridedMatrix{Complex{T}} where T<:AbstractFloat`

which prevents the use of `ForwardDiff.Dual`, since Dual numbers are not `<:AbstractFloat`.
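
A minimal sketch of the kind of call that fails (the matrix and the scalar function are made up for illustration; the exact error may vary by version):

```julia
using GenericSchur, ForwardDiff, LinearAlgebra

A(t) = [1.0 t; -t 2.0]   # toy parameter-dependent matrix

# a scalar quantity computed from the Schur factorization
spectral_abscissa(t) = maximum(real, schur(complex.(A(t))).values)

spectral_abscissa(0.1)                          # fine for Float64

# fails, typically with a MethodError: the matrix eltype is
# Complex{ForwardDiff.Dual{...}}, which is not Complex{<:AbstractFloat}
ForwardDiff.derivative(spectral_abscissa, 0.1)
```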

RalphAS commented 1 year ago

My understanding is that differentiation of spectral decompositions is more reliable, accurate, and (usually) efficient if implemented as independent algorithms based on analysis. This is done in ChainRules for `eigen`, but not yet for `schur`. (I've heard of recent analytic work on the Schur adjoint but I haven't had the chance to study it.)
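
To illustrate what "based on analysis" means (just a sketch, nothing this package provides): for a simple eigenvalue with right eigenvector `v` and left eigenvector `w`, the first-order perturbation result gives the directional derivative of the eigenvalue along `dA` as `wᵀ dA v / (wᵀ v)`, so the sensitivity comes from the factorization's output rather than from differentiating through the iteration:

```julia
using LinearAlgebra

# Directional derivative of the k-th eigenvalue of A along the perturbation dA,
# valid for simple (non-repeated) eigenvalues.
function eigval_directional_derivative(A, dA, k)
    F  = eigen(A)
    v  = F.vectors[:, k]
    wT = transpose(inv(F.vectors)[k, :])   # k-th left eigenvector, as a row
    (wT * dA * v) / (wT * v)
end

A  = [1.0 2.0; 0.5 3.0]
dA = [0.0 1.0; 0.0 0.0]
eigval_directional_derivative(A, dA, 1)
# ≈ (eigvals(A + 1e-6 * dA)[1] - eigvals(A)[1]) / 1e-6  (finite-difference check)
```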

Many of the methods here depend on details of floating point arithmetic for edge cases, in an effort to emulate the reliability of LAPACK. I would be surprised if those sections are good fits for Dual. In retrospect perhaps I should not have put "Generic" in the package name.

That being said, the GenericLinearAlgebra package has less precise but more generic implementations. If one really wants naive autodiff, that would be a better avenue.

In fact, an important advantage of the `<:AbstractFloat` restriction here is compatibility with GenericLinearAlgebra, so a `Dual` derivative from the latter should even work with the specializations provided here (except in edge cases where perturbations are likely bizarre anyway).
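
As a rough sketch of what the "naive autodiff" route looks like (untested; whether the generic code paths accept `Dual` element types is something one would have to verify case by case):

```julia
using GenericLinearAlgebra   # generic (non-LAPACK) linear algebra implementations
using ForwardDiff, LinearAlgebra

A(t) = [1.0 t; -t 2.0]

# If the generic eigvals path accepts Dual eltypes, this pushes the Dual
# numbers straight through the iteration -- "naive" autodiff of the solver.
ForwardDiff.derivative(t -> maximum(real, eigvals(A(t))), 0.1)
```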

mohamed82008 commented 1 year ago

You may want to check DifferentiableFactorizations.jl. I have Schur decompositions supported there via ImplicitDifferentiation.jl, which supports both forward- and reverse-mode AD and uses the implicit function theorem to avoid differentiating through the solver entirely.
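
Roughly like this (written from memory, so double-check the exact function name and the returned fields against the package README):

```julia
using DifferentiableFactorizations, ForwardDiff

A(t) = [1.0 t; -t 2.0]

# Differentiate a scalar summary of the Schur factorization w.r.t. t.
# `diff_schur` and the field name `T` are assumptions here -- see the docs.
g(t) = sum(abs2, diff_schur(A(t)).T)
ForwardDiff.derivative(g, 0.3)
```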

baggepinnen commented 1 year ago

Thanks for your comments and insight, @RalphAS.

I have tried ImplicitDifferentiation.jl and it worked, but it was 30x slower than finite differences due to a lot of type instabilities related to Duals. This line in particular does not fully specify the Dual type: https://github.com/gdalle/ImplicitDifferentiation.jl/blob/main/ext/ImplicitDifferentiationForwardDiffExt.jl#L42

I think we will pursue another approach for my particular problem, where we specify the diff rule for a higher-level function instead.
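
For the record, the kind of higher-level rule I mean looks roughly like this (a sketch only; `spectral_abscissa` is a made-up stand-in for the real function, and the rule uses the standard first-order perturbation formula for a simple eigenvalue):

```julia
using ChainRulesCore, LinearAlgebra

# Hypothetical higher-level function we actually care about.
spectral_abscissa(A::AbstractMatrix) = maximum(real, eigvals(A))

# Hand-written forward-mode rule: compute the sensitivity analytically
# instead of differentiating through the factorization itself.
function ChainRulesCore.frule((_, ΔA), ::typeof(spectral_abscissa), A::AbstractMatrix)
    F  = eigen(A)
    k  = argmax(real.(F.values))
    v  = F.vectors[:, k]
    wT = transpose(inv(F.vectors)[k, :])   # matching left eigenvector (row of V⁻¹)
    Ω  = real(F.values[k])
    ∂Ω = real((wT * ΔA * v) / (wT * v))    # sensitivity of the dominant eigenvalue
    return Ω, ∂Ω
end
```

ForwardDiff itself does not consume ChainRules rules, so for ForwardDiff specifically the equivalent move is to add a method of the function for `Dual`-valued inputs; the analytic formula is the same either way.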