sethaxen opened this issue 4 years ago
There should be ways to do this without defining the Jacobian. expv is the solution to the linear ODE du/dt = Au, so the adjoint of the ODE should be usable to derive the expression, which IIRC should just be: the adjoint equation for du/dt = Au evolves λ via λ'*A, which means the adjoint should just be expv(A, delta').
For λ = expv(t, A, u) and an adjoint Δλ of λ, the adjoint of u should, I think, be ∂u = expv(t, A', Δλ), which is quite nice.
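To sanity-check that claim, here's a quick dense NumPy sketch (Python with `expm` standing in for `expv`; everything here is a toy for illustration, not the package code):

```python
import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(0)
n, t = 4, 0.7
A = rng.standard_normal((n, n))
u = rng.standard_normal(n)
dl = rng.standard_normal(n)  # Δλ, the cotangent of the output

J = expm(t * A)               # Jacobian of u -> exp(tA) u
du_direct = J.T @ dl          # VJP via the explicit Jacobian
du_expv = expm(t * A.T) @ dl  # "expv(t, A', Δλ)" with a dense matrix

assert np.allclose(du_direct, du_expv)
```

Both agree because exp(tA)' = exp(tA'), so the pullback through u really is one more expv call with the adjoint operator.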
We also need the adjoints ∂t and ∂A, which will take more thought. Especially since A doesn't even need to be a matrix, right? I don't think we'll be able to support all types of A for a custom adjoint, just AbstractMatrixes.
Yeah, the difficult thing will be supporting an A that's not a concrete matrix, since then it can't be adjointed directly. But then that's just defined as the reverse mode of the function f(u) = A*u, so I think it can work out; it'll just be more complicated in code.
Those again would come from this derivation. You might want to read https://diffeq.sciml.ai/stable/extras/sensitivity_math/ or the supplemental of https://arxiv.org/abs/2001.04385. Specifically, the ∂A term is given by an integral over the Lagrange multiplier term. Coincidentally, the phiv values used in the exponential integrators are these integrals, so the adjoint can probably be written as just a calculation of phi_1. I think it's like phiv(t, A', Δλ) + reversemode(A) kind of thing (in pseudocode, off the top of my head, so maybe missing a detail somewhere).
∂t is easy in this interpretation: λ = expv(t, A, u) = exp(tA)u is the solution of λ′ = Aλ with λ(0) = u evaluated at time t, so the derivative of the solution w.r.t. t is just Aλ (or in reverse mode, maybe something with A').
Again, all might be missing a detail since I'm doing it quickly, but that should be the gist of it.
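A dense finite-difference check of both directions of the ∂t claim (again a Python toy with `expm`, not the package code; the assumption here is that the reverse-mode scalar adjoint of t is ⟨Δλ, Aλ⟩):

```python
import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(1)
n, t, h = 4, 0.6, 1e-6
A = rng.standard_normal((n, n))
u = rng.standard_normal(n)
dl = rng.standard_normal(n)  # Δλ, cotangent of λ

lam = expm(t * A) @ u
# forward: dλ/dt = Aλ, checked with central finite differences
dlam_fd = (expm((t + h) * A) @ u - expm((t - h) * A) @ u) / (2 * h)
assert np.allclose(dlam_fd, A @ lam, atol=1e-6)

# reverse: the scalar adjoint of t should then be <Δλ, Aλ>
dt_rev = dl @ (A @ lam)
dt_fd = (dl @ (expm((t + h) * A) @ u) - dl @ (expm((t - h) * A) @ u)) / (2 * h)
assert np.isclose(dt_rev, dt_fd, atol=1e-6)
```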
Thanks! That should be enough to get me started. I'll probably tackle this in a few months if no one else does before then (unless I find some time early).
Working on this now and have some follow-up questions.
> Specifically, the ∂A term is given by an integral over the Lagrange multiplier term. Coincidentally, the phiv values used in the exponential integrators are these integrals, so the adjoint can probably be written as just a calculation of phi_1. I think it's like phiv(t, A', Δλ) + reversemode(A) kind of thing (in pseudocode, off the top of my head, so maybe missing a detail somewhere).
I've spent some time working through the provided references and still haven't comprehended this comment. What is reversemode(A) here? By phiv(t, A', Δλ), do you mean phiv(t, A', Δλ, 1)[:, 2], which I believe computes \phi_1(tA') Δλ? That would compute an adjoint of the same dimension as v, not a matrix.
Hmm, I guess it doesn't use the phi_1. It is the first integral of the term so I'm a little surprised it doesn't show up.
Okay, I think I worked something out for forward mode at least. The pushforward of u = expv(t, A, u_0) is (using slide 5 of http://www1.maths.leeds.ac.uk/~jitse/scicade09.pdf):

Δu = A \phi_0(tA) u_0 Δt + (\phi_0(tA) Δu_0 + \sum_{i=1}^\infty t^i \phi_i(tA) ΔA A^{i-1} u_0),

the part in parentheses being the solution at time t of the ODE Δu′ = A Δu + ΔA u with Δu(0) = Δu_0.
Perhaps there's some way to simplify that hideous sum term. Still need to work out the corresponding reverse mode.
Following up on @ChrisRackauckas's point, we can indeed compute the adjoint of A by solving an ODE in reverse. A working prototype is here:
https://gist.github.com/sethaxen/4071b401b9b4ff4f5421136cec2fa7da/dd914b79d465d8653b1674cbc466f5a29d95fbae#file-expv_chainrules-jl-L64-L77
I haven't worked out how to solve this ODE using just the functions in this package; currently I require OrdinaryDiffEq. This does what I need right now, so I'll put #51 on hold until I work out something efficient I can do using just this package.
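For what it's worth, the reverse-solve can be sanity-checked in dense form: writing the adjoint of A as ∫_0^t λ(s) u(t−s)' ds with λ(s) = exp(sA')Δw and u(r) = exp(rA)v (my own restatement of the integral, so treat it as hedged), quadrature agrees with finite differences. A Python toy, not the gist's code:

```python
import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(4)
n, t = 3, 0.9
A = rng.standard_normal((n, n))
v = rng.standard_normal(n)
dw = rng.standard_normal(n)  # Δw, cotangent of w = exp(tA) v

# ∂A = ∫_0^t λ(s) u(t-s)' ds, by composite trapezoidal quadrature
s_grid = np.linspace(0.0, t, 2001)
vals = np.array([np.outer(expm(s * A.T) @ dw, expm((t - s) * A) @ v)
                 for s in s_grid])
ds = s_grid[1] - s_grid[0]
dA_quad = ds * (vals[0] / 2 + vals[-1] / 2 + vals[1:-1].sum(axis=0))

# finite-difference gradient of g(A) = Δw' exp(tA) v, entry by entry
h = 1e-6
dA_fd = np.empty((n, n))
for i in range(n):
    for j in range(n):
        P = np.zeros((n, n)); P[i, j] = h
        dA_fd[i, j] = (dw @ (expm(t * (A + P)) @ v)
                       - dw @ (expm(t * (A - P)) @ v)) / (2 * h)

assert np.allclose(dA_quad, dA_fd, atol=1e-4)
```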
Another way to compute the adjoint of A comes from https://doi.org/10.1109/TAC.1978.1101743. Let w = expv(t, A, v), let Δw be the adjoint of w, and let ∂v = expv(t, A', Δw) be the pulled-back adjoint of v. The adjoint of A is the integral \int_0^t exp(s A') Δw w' exp(-s A') ds. Define the block-triangular matrix D = [-A' ∂v*w'; zero(A) -A']. Then the upper-right block of exp(t * D) is the adjoint of A. This is fine for small dense A but is otherwise very inefficient, so this doesn't seem useful.
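A quick dense check of this block-triangular construction (Python toy with `expm` standing in for `expv`, not package code):

```python
import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(5)
n, t = 3, 0.7
A = rng.standard_normal((n, n))
v = rng.standard_normal(n)
dw = rng.standard_normal(n)     # Δw, cotangent of w

w = expm(t * A) @ v             # forward value
dv = expm(t * A.T) @ dw         # ∂v = "expv(t, A', Δw)"
D = np.block([[-A.T, np.outer(dv, w)],
              [np.zeros((n, n)), -A.T]])
dA_block = expm(t * D)[:n, n:]  # upper-right block

# compare to a finite-difference gradient of g(A) = Δw' exp(tA) v
h = 1e-6
dA_fd = np.empty((n, n))
for i in range(n):
    for j in range(n):
        P = np.zeros((n, n)); P[i, j] = h
        dA_fd[i, j] = (dw @ (expm(t * (A + P)) @ v)
                       - dw @ (expm(t * (A - P)) @ v)) / (2 * h)

assert np.allclose(dA_block, dA_fd, atol=1e-5)
```

The inefficiency complaint stands: D is a dense 2n × 2n matrix, so this doubles the problem size just to read off one block.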
Here's where I landed on this. The adjoint for A will be computed by hand-deriving the pullback through exp and arnoldi/lanczos. The former will be added to ChainRules (https://github.com/JuliaDiff/ChainRules.jl/issues/331). I locally have an implementation of the latter that requires no checkpointing.
For an n × n matrix A, the final step of the pullback for arnoldi is the product of an n × m matrix and the adjoint of another n × m matrix, where m is the dimension of the Krylov subspace. For dense A this is just a matmul, but for huge sparse A we would need to know its sparsity pattern to avoid creating a huge dense matrix, and instead only compute certain dot products of columns.
We need a function like outer_sparse!(∂A, x::AbstractVecOrMat, y::AbstractVecOrMat), where ∂A is a differential type for A (either Composite{typeof(A)} or an AbstractMatrix), that does this. We can implement such a function for all AbstractMatrix types in base Julia and define the rrule only for those types, wrapping an expv_rev that has no type constraints. Then an implementer of a custom operator can overload outer_sparse! for their operator and define an rrule wrapping expv_rev. Unfortunately, this would require the array package to depend on ExponentialUtilities, or a user to commit type piracy.
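A sketch of what such a function could look like, in Python/SciPy rather than Julia just to illustrate the idea (`outer_sparse`, the `pattern` argument, and everything else here are hypothetical names, not an existing API):

```python
import numpy as np
import scipy.sparse as sp

def outer_sparse(pattern, X, Y):
    """Compute (X @ Y.T) only at the stored entries of `pattern`.

    Hypothetical analogue of the proposed outer_sparse!: X and Y are
    n-by-m (m = Krylov dimension), and we avoid forming the dense
    n-by-n outer product by taking one length-m dot product per
    stored entry of the sparsity pattern.
    """
    rows, cols = pattern.nonzero()
    data = np.einsum("km,km->k", X[rows], Y[cols])
    return sp.csr_matrix((data, (rows, cols)), shape=pattern.shape)

# toy usage: compare against the dense product restricted to the pattern
rng = np.random.default_rng(6)
n, m = 6, 3
X, Y = rng.standard_normal((n, m)), rng.standard_normal((n, m))
pattern = sp.random(n, n, density=0.3, random_state=0, format="csr")
out = outer_sparse(pattern, X, Y)
dense = X @ Y.T
r, c = pattern.nonzero()
assert np.allclose(out.toarray()[r, c], dense[r, c])
```

The cost is O(nnz(A) · m) instead of O(n² · m), which is the point for huge sparse A.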
(Note that this block-triangular rule is a special case of an algorithm to differentiate matrix functions by Mathias in 1996, as discussed in https://github.com/JuliaDiff/ChainRules.jl/issues/764)
From Slack, @sethaxen (to @ChrisRackauckas): I in particular need adjoints for expv. Zygote currently has an adjoint rule for exp(::AbstractMatrix) and exp(::Hermitian) using the eigendecomposition. I imagine though there's a better way to implement the adjoints for expv by looking at the underlying algorithm (I have not).