SciML / ExponentialUtilities.jl

Fast and differentiable implementations of matrix exponentials, Krylov exponential matrix-vector multiplications ("expmv"), KIOPS, ExpoKit functions, and more. All your exponential needs in SciML form.
https://docs.sciml.ai/ExponentialUtilities/stable/

Add ChainRules rules #40

sethaxen opened this issue 4 years ago

sethaxen commented 4 years ago

From Slack: @sethaxen:

Does ExponentialUtilities.jl play well with AD packages, in particular Zygote?

@ChrisRackauckas:

Not fully with Zygote; it'll need adjoints, since it's doing a lot of scalar stuff (it's writing the kernels directly). The adjoints are easy, though.

I in particular need adjoints for expv. Zygote currently has an adjoint rule for exp(::AbstractMatrix) and exp(::Hermitian) using the eigendecomposition. I imagine, though, that there's a better way to implement the adjoints for expv by looking at the underlying algorithm (which I have not yet done).
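For context, the existing Zygote rule mentioned above can be exercised directly on a small dense matrix. A minimal sketch, assuming the eigendecomposition-based adjoint for exp(::AbstractMatrix) that this thread refers to is available:

```julia
using Zygote, LinearAlgebra

A = Matrix(Symmetric(randn(3, 3)))  # symmetric A keeps the eigendecomposition well behaved
Y, back = Zygote.pullback(exp, A)   # forward pass through the matrix exponential
ΔY = randn(3, 3)                    # incoming cotangent for exp(A)
(∂A,) = back(ΔY)                    # adjoint of A via Zygote's exp rule
```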

ChrisRackauckas commented 4 years ago

There should be ways to do this without defining the Jacobian. expv is the solution of a linear ODE, so the adjoint of the ODE should be able to be used to derive the expression for the adjoint, which IIRC should just be:

du/dt = A u  →  dλ/dt = A' λ

which means the adjoint should just be expv(t, A', Δλ).

sethaxen commented 4 years ago

For λ = expv(t, A, u) with Δλ the adjoint of λ, the adjoint of u should, I think, be ∂u = expv(t, A', Δλ), which is quite nice. We also need the adjoints ∂t and ∂A, which will take more thought.
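A quick numerical sanity check of that formula against the dense exponential; a sketch with a small random A, using the fact that expv(t, A, b) computes exp(tA)b:

```julia
using ExponentialUtilities, LinearAlgebra

n = 4
A = randn(n, n); u = randn(n); Δλ = randn(n); t = 0.7

∂u = expv(t, A', Δλ)        # proposed pullback w.r.t. u
∂u_ref = exp(t * A)' * Δλ   # reference: the Jacobian of exp(tA)u w.r.t. u is exp(tA)
@assert ∂u ≈ ∂u_ref
```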

sethaxen commented 4 years ago

Especially since A doesn't even need to be a matrix, right? I don't think we'll be able to support all types of A in a custom adjoint, just AbstractMatrix subtypes.

ChrisRackauckas commented 4 years ago

Yeah, the difficult thing will be supporting something that's not a concrete matrix, since then it can't be adjointed directly. But then that's just defined as the reverse mode of the function f(u) = A*u, so I think it can work out; it'll just be more complicated in code.

Those again would come from this derivation. You might want to read https://diffeq.sciml.ai/stable/extras/sensitivity_math/ or the supplemental of https://arxiv.org/abs/2001.04385. Specifically, the ∂A term is given by an integral over the Lagrange multiplier term. Coincidentally, the phiv values used in the exponential integrators are these integrals, so the adjoint can probably be written as just a calculation of phi_1. I think it's like phiv(t, A', Δλ) + reversemode(A) kind of thing (in pseudocode, off the top of my head, so maybe missing a detail somewhere).

∂t is easy in this interpretation: λ = expv(t, A, u) = exp(tA)u is the solution of λ′ = Aλ with λ(0) = u, solved out to time t, so the derivative of the solution w.r.t. t is just Aλ (or in reverse mode, maybe with A').

Again, all might be missing a detail since I'm doing it quickly, but that should be the gist of it.
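Concretely, that interpretation can be checked by finite differences. A sketch, assuming the reverse-mode sensitivity implied above is ∂t = ⟨Δλ, Aλ⟩:

```julia
using ExponentialUtilities, LinearAlgebra

A = randn(4, 4); u = randn(4); Δλ = randn(4); t = 0.3

λ = expv(t, A, u)
∂t = dot(Δλ, A * λ)           # reverse-mode sensitivity w.r.t. t

h = 1e-6                      # central finite-difference comparison
∂t_fd = dot(Δλ, (expv(t + h, A, u) - expv(t - h, A, u)) / (2h))
@assert isapprox(∂t, ∂t_fd; rtol=1e-5)
```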

sethaxen commented 4 years ago

Thanks! That should be enough to get me started. I'll probably tackle this in a few months if no one else does before then (unless I find some time early).

sethaxen commented 3 years ago

Working on this now and have some follow-up questions.

Specifically, the ∂A term is given by an integral over the Lagrange multiplier term. Coincidentally, the phiv values used in the exponential integrators are these integrals, so the adjoint can probably be written as just a calculation of phi_1. I think it's like phiv(t, A', Δλ) + reversemode(A) kind of thing (in pseudocode, off the top of my head, so maybe missing a detail somewhere).

I've spent some time working through the provided references and still haven't comprehended this comment. What is reversemode(A) here? By phiv(t, A', Δλ) do you mean phiv(t, A', Δλ, 1)[:, 2], which I believe computes φ₁(tA') Δλ? This would compute an adjoint of the same dimension as v, not a matrix.
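For reference, the column layout of phiv can be checked directly against dense formulas. A sketch with small dense A; the reference φ₁ is reconstructed here from its definition φ₁(z) = (eᶻ − 1)/z:

```julia
using ExponentialUtilities, LinearAlgebra

A = randn(4, 4); b = randn(4); t = 0.5

P = phiv(t, A, b, 1)                  # columns: [φ₀(tA)b  φ₁(tA)b]
@assert P[:, 1] ≈ exp(t * A) * b      # φ₀(tA) b = exp(tA) b
φ₁(M) = (exp(M) - I) / M              # dense reference for the matrix φ₁
@assert P[:, 2] ≈ φ₁(t * A) * b
```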

ChrisRackauckas commented 3 years ago

Hmm, I guess it doesn't use the φ₁. It is the first integral of the term, so I'm a little surprised it doesn't show up.

sethaxen commented 3 years ago

Okay, I think I worked something out for forward mode at least. Using slide 5 of http://www1.maths.leeds.ac.uk/~jitse/scicade09.pdf, the pushforward of u = expv(t, A, u_0) is

Δu = A φ₀(tA) u₀ Δt + (φ₀(tA) Δu₀ + ∑_{i=1}^∞ t^i φᵢ(tA) ΔA A^{i-1} u₀),

where the part in parentheses is the solution of the ODE Δu′ = A Δu + ΔA u with Δu(0) = Δu₀. Perhaps there's some way to simplify that hideous sum term. Still need to work out the corresponding reverse mode.
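The parenthesized part (the solution of Δu′ = A Δu + ΔA u) can be computed without the infinite sum via a single expv call on a block upper-triangular augmentation. A sketch; the block construction is Van Loan's trick, not something exported by this package:

```julia
using ExponentialUtilities, LinearAlgebra

n = 4
A, ΔA = randn(n, n), randn(n, n)
u0, Δu0 = randn(n), randn(n)
t, Δt = 0.5, 0.1

M = [A ΔA; zero(A) A]              # d/ds [Δu; u] = M [Δu; u]
z = expv(t, M, [Δu0; u0])[1:n]     # = φ₀(tA)Δu₀ + ∑ᵢ tⁱ φᵢ(tA) ΔA A^{i-1} u₀
Δu = A * expv(t, A, u0) * Δt + z   # full pushforward, including the Δt term

# finite-difference check of the ΔA direction alone
h = 1e-6
fd = (exp(t * (A + h * ΔA)) - exp(t * (A - h * ΔA))) * u0 / (2h)
@assert isapprox(expv(t, M, [zero(u0); u0])[1:n], fd; rtol=1e-4)
```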

sethaxen commented 3 years ago

Following up on @ChrisRackauckas's point, we can indeed compute the adjoint of A by solving an ODE in reverse. A working prototype is here: https://gist.github.com/sethaxen/4071b401b9b4ff4f5421136cec2fa7da/dd914b79d465d8653b1674cbc466f5a29d95fbae#file-expv_chainrules-jl-L64-L77

I haven't worked out how to solve this ODE using just the functions in this package; currently I require OrdinaryDiffEq. This does what I need right now, so I'll put #51 on hold until I work out something efficient using just this package.
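The idea, roughly, is ∂A = ∫₀ᵗ λ(s) u(s)' ds with u′ = Au, u(0) = v and λ(s) = exp((t−s)A')Δw, accumulated as one augmented ODE. A sketch of that shape (my own reconstruction, not the gist's exact code):

```julia
using OrdinaryDiffEq, ExponentialUtilities, LinearAlgebra

n = 4
A = randn(n, n); v = randn(n); Δw = randn(n); t = 0.5

λ0 = expv(t, A', Δw)    # λ(0) = exp(tA')Δw

# state z = [u; λ; vec(G)]: u' = A u, λ' = -A' λ, G' = λ u'
function f!(dz, z, p, s)
    u = @view z[1:n]
    λ = @view z[n+1:2n]
    mul!(view(dz, 1:n), A, u)                         # u' = A u
    mul!(view(dz, n+1:2n), A', λ, -1, 0)              # λ' = -A' λ
    reshape(view(dz, 2n+1:2n+n^2), n, n) .= λ .* u'   # G' = λ u' (outer product)
    return nothing
end

prob = ODEProblem(f!, vcat(v, λ0, zeros(n^2)), (0.0, t))
sol = solve(prob, Tsit5(); abstol=1e-10, reltol=1e-10)
∂A = reshape(sol.u[end][2n+1:end], n, n)   # = ∫₀ᵗ λ(s) u(s)' ds
```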

sethaxen commented 3 years ago

Another way to compute the adjoint of A comes from https://doi.org/10.1109/TAC.1978.1101743. Let w = expv(t, A, v), let Δw be the adjoint of w, and let ∂v = expv(t, A', Δw) be the pulled-back adjoint of v. The adjoint of A is the integral ∫₀ᵗ exp(s A') Δw w' exp(-s A') ds. Define the block-triangular matrix D = [-A' ∂v*w'; zero(A) -A']. Then the upper-right block of exp(t * D) is the adjoint of A. This is fine for small dense A but is otherwise very inefficient, so it doesn't seem useful.
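A sketch of that construction with a directional finite-difference check; it only makes sense for small dense A, and the variable names are mine:

```julia
using ExponentialUtilities, LinearAlgebra

n = 4
A = randn(n, n); v = randn(n); t = 0.5
w = expv(t, A, v)
Δw = randn(n)                      # incoming adjoint of w

∂v = expv(t, A', Δw)               # pulled-back adjoint of v
D = [-A' ∂v*w'; zero(A) -A']       # block-triangular matrix from above
∂A = exp(t * D)[1:n, n+1:2n]       # upper-right block of exp(tD)

# directional check: ⟨∂A, ΔA⟩ ≈ ⟨Δw, directional derivative of expv in ΔA⟩
ΔA = randn(n, n); h = 1e-6
fd = dot(Δw, (expv(t, A + h*ΔA, v) - expv(t, A - h*ΔA, v)) / (2h))
@assert isapprox(dot(∂A, ΔA), fd; rtol=1e-4)
```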

sethaxen commented 3 years ago

Here's where I landed on this. The adjoint for A will be computed by hand-deriving the pullback through exp and through arnoldi/lanczos. The former will be added to ChainRules (https://github.com/JuliaDiff/ChainRules.jl/issues/331). I locally have an implementation of the latter that requires no checkpointing.

For an n × n matrix A, the final step of the pullback for arnoldi is the product of an n × m matrix and the adjoint of another n × m matrix, where m is the dimension of the Krylov subspace. For dense A this is just a matmul, but for huge sparse A we would need to know its sparsity pattern to avoid creating a huge dense matrix, and instead compute only the dot products corresponding to the stored entries.

We need a function like outer_sparse!(∂A, x::AbstractVecOrMat, y::AbstractVecOrMat) that does this, where ∂A is a differential type for A (either a Composite{typeof(A)} or an AbstractMatrix). We can implement such a function for all AbstractMatrix types in base Julia and define the rrule only for those types, wrapping an expv_rev that has no type constraints. Then an implementer of a custom operator can overload outer_sparse! for their operator and define an rrule wrapping expv_rev. Unfortunately, this would require the array package to depend on ExponentialUtilities, or a user to commit type piracy. A sketch of such a function follows.
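A minimal sketch of what outer_sparse! could look like for a SparseMatrixCSC; the name and signature are proposals from this thread, not an existing API:

```julia
using SparseArrays, LinearAlgebra

# Accumulate only the entries of X*Y' that lie on the sparsity pattern of ∂A,
# avoiding the dense n × n outer product.
function outer_sparse!(∂A::SparseMatrixCSC, X::AbstractMatrix, Y::AbstractMatrix)
    rows = rowvals(∂A); vals = nonzeros(∂A)
    for j in axes(∂A, 2), k in nzrange(∂A, j)
        i = rows[k]
        # (X*Y')[i, j] = Σₘ X[i, m] conj(Y[j, m]) = dot(Y[j, :], X[i, :])
        vals[k] += dot(view(Y, j, :), view(X, i, :))
    end
    return ∂A
end
```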

stevengj commented 8 months ago

(Note that this block-triangular rule is a special case of an algorithm by Mathias (1996) for differentiating matrix functions, as discussed in https://github.com/JuliaDiff/ChainRules.jl/issues/764.)