sethaxen / ExponentialAction.jl

Compute the action of the matrix exponential
MIT License

Enzyme support #26

Open sethaxen opened 8 months ago

sethaxen commented 8 months ago

I just checked, and on Julia v1.10.0, Enzyme v0.11.15 can differentiate `expv` without trouble. It fails on `expv_sequence`, though. Here's the code:

```julia
using Enzyme, ExponentialAction

f(t, A, v) = sum(ExponentialAction.expv(t, A, v))
fseq(tmin, tmax, A, v) = sum(sum, ExponentialAction.expv_sequence(range(tmin, tmax, 10), A, v))

# Reverse-mode gradient of f with respect to t, A, and v
function grad_enzyme(t, A, v)
    dA = fill!(similar(A), 0)  # shadow that accumulates ∂f/∂A
    dv = fill!(similar(v), 0)  # shadow that accumulates ∂f/∂v
    (dt,), = autodiff(Reverse, f, Active, Active(t), Duplicated(A, dA), Duplicated(v, dv))
    return dt, dA, dv
end

# Same, but for the sum over 10 evenly spaced times in [tmin, tmax]
function grad_enzyme_seq(tmin, tmax, A, v)
    dA = fill!(similar(A), 0)
    dv = fill!(similar(v), 0)
    (dtmin, dtmax), = autodiff(Reverse, fseq, Active, Active(tmin), Active(tmax), Duplicated(A, dA), Duplicated(v, dv))
    return dtmin, dtmax, dA, dv
end

tmin, tmax = sort(rand(2))
A = randn(30, 30)
v = randn(size(A, 2))

grad_enzyme(tmin, A, v)  # fine
grad_enzyme_seq(tmin, tmax, A, v)  # errors
```
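As a sanity check that the `expv` gradient is correct and not merely error-free, one option is to compare Enzyme's derivative with respect to `t` against the closed form d/dt sum(exp(tA)v) = sum(A exp(tA)v). This is a minimal sketch assuming the same package versions as above; `dfdt_enzyme` and `dfdt_exact` are illustrative helper names, and `A` and `v` are marked `Const` so that only `t` is differentiated.

```julia
using Enzyme, ExponentialAction

f(t, A, v) = sum(ExponentialAction.expv(t, A, v))

# Hypothetical helper: Enzyme reverse-mode derivative of f w.r.t. t only (A, v held constant)
function dfdt_enzyme(t, A, v)
    (dt,), = autodiff(Reverse, f, Active, Active(t), Const(A), Const(v))
    return dt
end

# Hypothetical helper: closed form, using d/dt exp(tA)v = A exp(tA)v
dfdt_exact(t, A, v) = sum(A * ExponentialAction.expv(t, A, v))

t = rand()
A = randn(30, 30)
v = randn(size(A, 2))
println(isapprox(dfdt_enzyme(t, A, v), dfdt_exact(t, A, v); rtol=1e-6))
```

Because `t` is the only `Active` argument here, no shadow arrays are needed.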

Enzyme also succeeded on `expv` for every structured matrix type from LinearAlgebra that I tried, but it failed for `SparseMatrixCSC`.

Unlike with the other AD frameworks, we don't have many options for improving things on our end with Enzyme. In principle, I think Enzyme's activity analysis should avoid differentiating through code that is used only for control flow. But if we can reduce the `expv_sequence` error to a simpler example, we may need to open an issue on Enzyme.
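A toy analogue of the control-flow point (hypothetical code, not ExponentialAction's implementation): `expv` picks a Taylor degree at runtime, and in this sketch `taylor_exp` likewise chooses its number of terms `m` from the input. Since `m` is an integer, it only decides how many loop iterations run and carries no derivative; Enzyme differentiates just the floating-point updates inside the loop.

```julia
using Enzyme

# Hypothetical toy: truncated Taylor series for exp(x) whose degree depends on x
function taylor_exp(x)
    m = ceil(Int, 10 + abs(x))  # degree selection: integer, used only for control flow
    s = one(x)
    term = one(x)
    for k in 1:m
        term *= x / k  # term = x^k / k!
        s += term
    end
    return s
end

(dx,), = autodiff(Reverse, taylor_exp, Active, Active(1.0))
println(isapprox(dx, exp(1.0); rtol=1e-6))
```

The derivative of the degree-11 polynomial at `x = 1` is its degree-10 truncation, which differs from `exp(1)` by less than 1e-7.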

FWIW, Enzyme still errors on `ExponentialUtilities.expv`.