EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

AD through ODE solution #1818

Closed: rkurchin closed this issue 2 months ago

rkurchin commented 2 months ago

Reposting from Slack at @wsmoses' request...

I’m trying to differentiate through an ODE solution with Enzyme and running into some errors. Basically, I started with the example on this docs page and just tried to Enzyme-ify it like so:

using SciMLSensitivity, OrdinaryDiffEq, Enzyme

function fiip(du, u, p, t)
    du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2]
end
p = [1.5, 1.0, 3.0, 1.0];
u0 = [1.0; 1.0];
prob = ODEProblem(fiip, u0, (0.0, 10.0), p)
sol = solve(prob, Tsit5())
loss(u0, p) = sum(solve(prob, Tsit5(), u0 = u0, p = p, saveat = 0.1))
du = zeros(size(u0))
dp = zeros(size(p))
dp[1] = 1.0
Enzyme.autodiff(Reverse, loss, Active, Duplicated(u0, du), Duplicated(p, dp))

When I initially posted about this on Slack, I was on my desktop on Julia 1.10 and got ERROR: AssertionError: Enzyme Internal Error: did not have sret when expected f=; Function Attrs: alwaysinline mustprogress willreturn, along with a preposterously long stacktrace that I can pull and attach here when I get back to my desk at work tomorrow.

Interestingly, when I tried to reproduce this on my laptop at home on Julia 1.9, something different happened (attached to this message because it was still too long to paste in): ODE_Enzyme_Julia19.txt

Here's the full environment...

(@v1.9) pkg> st
Status `~/.julia/environments/v1.9/Project.toml`
  [ac637c84] AbbreviatedStackTraces v0.2.3
  [6e4b80f9] BenchmarkTools v1.5.0
  [acf6eb54] DFTK v0.6.20
⌅ [7da242da] Enzyme v0.11.20
  [f6369f11] ForwardDiff v0.10.36
  [5fb14364] OhMyREPL v0.5.28
⌅ [1dea7af3] OrdinaryDiffEq v6.66.0
  [14b8a8f1] PkgTemplates v0.7.52
  [c3e4b0f8] Pluto v0.19.46
  [295af30f] Revise v3.5.18
⌅ [1ed8b502] SciMLSensitivity v7.51.0
  [b8865327] UnicodePlots v3.6.4
  [1986cc42] Unitful v1.21.0
  [a7773ee8] UnitfulAtomic v1.0.0
  [37e2e46d] LinearAlgebra
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`

As I said, I'll grab the output from running this on 1.10 and the associated environment details and add those here tomorrow. LMK if any other info would be helpful in the meantime!

wsmoses commented 2 months ago

@rkurchin it looks like you're on a really old version of Enzyme, what happens if you use the latest?

rkurchin commented 2 months ago

Yup, I haven't set up juliaup on my laptop at home yet, so it still has Julia 1.9 (those were the latest versions of everything it would let me install). Now that I'm back at work, here's the output from the same code on Julia 1.10 (the top-line error is ERROR: AssertionError: Enzyme Internal Error: did not have sret when expected): ODE_Enzyme_Julia110.txt

and environment info...

(@v1.10) pkg> st
Status `~/.julia/environments/v1.10/Project.toml`
  [7da242da] Enzyme v0.12.36
  [5fb14364] OhMyREPL v0.5.28
  [1dea7af3] OrdinaryDiffEq v6.89.0
  [14b8a8f1] PkgTemplates v0.7.52
  [295af30f] Revise v3.5.18
  [1ed8b502] SciMLSensitivity v7.67.0
  [1e6cf692] TestEnv v1.102.0

wsmoses commented 2 months ago

Okay, this should be solved by https://github.com/EnzymeAD/Enzyme.jl/pull/1829

Note that for this to run end to end, you'll probably want to get rid of the type instability caused by loss reading the non-constant global prob, by passing prob in as a Const argument instead:

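using SciMLSensitivity, OrdinaryDiffEq, Enzyme

# In-place Lotka-Volterra right-hand side
function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end
p = [1.5, 1.0, 3.0, 1.0];
u0 = [1.0; 1.0];
prob = ODEProblem(fiip, u0, (0.0, 10.0), p)
sol = solve(prob, Tsit5())
# Pass prob as an argument instead of capturing the non-constant global
loss(u0, p, prob) = sum(solve(prob, Tsit5(), u0 = u0, p = p, saveat = 0.1))
# Shadow arrays: in reverse mode Enzyme accumulates the gradients into these,
# so they should start at zero
du = zeros(size(u0))
dp = zeros(size(p))
# prob is not differentiated, so it is marked Const
Enzyme.autodiff(Reverse, loss, Active, Duplicated(u0, du), Duplicated(p, dp), Const(prob))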
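Why the global matters: a function that reads a non-constant global cannot have its return type inferred, because the global could be rebound to a value of any type at any time. The minimal sketch below uses hypothetical toy functions (scale_global and scale_arg, unrelated to the ODE code) and Base.return_types to show the difference; passing the value in as an argument, as loss(u0, p, prob) does, lets the compiler infer a concrete type.

```julia
# Hypothetical toy example: non-constant global vs. explicit argument
scale = 2.0                      # non-constant, untyped global binding

scale_global(x) = x * scale      # reads the global: its type is unknown to inference
scale_arg(x, s) = x * s          # value passed in: type comes from the argument

# Inferred return types for Float64 inputs
rt_global = Base.return_types(scale_global, (Float64,))
rt_arg = Base.return_types(scale_arg, (Float64, Float64))

println(rt_global)  # Any[Any]
println(rt_arg)     # Any[Float64]
```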

Also, this seemingly gives a nice speedup over Zygote, which is cool:

julia> @btime Zygote.gradient((u0,p)->loss(u0,p,prob), u0, p)
  8.795 ms (51250 allocations: 2.71 MiB)
([-39.49430913141475, -8.63188820376318], [7.349039360098285, -159.3107987785755, 74.93924800431942, -339.3272380777995])

julia> @btime Enzyme.autodiff(Reverse, loss, Active, Duplicated(u0, du), Duplicated(p, dp), Const(prob))
  456.900 μs (2834 allocations: 236.31 KiB)
((nothing, nothing, nothing),)
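About the ((nothing, nothing, nothing),) result: with Duplicated arguments, reverse-mode Enzyme writes the gradients into the shadow arrays (du and dp here) instead of returning them, and it adds to whatever the shadows already contain. That means repeated calls (as in @btime) keep accumulating, and any nonzero value placed in a shadow beforehand (such as dp[1] = 1.0 above) is simply added to the gradient. A minimal sketch, assuming a recent Enzyme.jl and using a hypothetical sumsq function:

```julia
using Enzyme

# Hypothetical toy loss: sum of squares, whose gradient is 2x
function sumsq(x)
    s = 0.0
    for xi in x
        s += xi^2
    end
    return s
end

x = [1.0, 2.0]
dx = zeros(length(x))    # shadow starts at zero

# The Active return is seeded with 1.0; the gradient is written into dx
res = Enzyme.autodiff(Reverse, sumsq, Active, Duplicated(x, dx))
println(res)   # ((nothing,),)  since Duplicated arguments return nothing
println(dx)    # [2.0, 4.0]

# A second call adds into the same shadow rather than overwriting it
Enzyme.autodiff(Reverse, sumsq, Active, Duplicated(x, dx))
println(dx)    # [4.0, 8.0]
```

To get a fresh gradient on each call, reset the shadow first, e.g. with fill!(dx, 0.0).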

cc @ChrisRackauckas since integration speedups are always nice

ChrisRackauckas commented 2 months ago

Yeah, what was the issue there? I don't quite know what that PR was doing.

Note that this is hitting the adjoint method on the ODE through the overload rather than differentiating the solver itself, so it's not quite "AD through ODE solution". That is https://github.com/SciML/OrdinaryDiffEq.jl/pull/2282, which still needs https://github.com/SciML/DiffEqBase.jl/pull/1073 but is basically done. I may just solve that today.
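To make the distinction concrete: "AD through the ODE solution" means differentiating the integrator's own arithmetic step by step, whereas the overload hands the derivative computation to SciMLSensitivity's sensitivity methods instead of differentiating the solver code. The sketch below is a hypothetical toy illustration of the former only: a hand-written fixed-step explicit Euler loop (not Tsit5, not OrdinaryDiffEq's implementation, and not what the linked PRs do) for the same Lotka-Volterra right-hand side, differentiated directly by Enzyme. The step size, step count, and loss are assumptions chosen for illustration.

```julia
using Enzyme

# Hypothetical toy integrator: fixed-step explicit Euler for the same
# Lotka-Volterra system. Enzyme differentiates through every step of this
# loop directly; no sensitivity method is involved.
function euler_loss(u0, p)
    dt = 0.01
    nsteps = 1000                 # integrates t from 0 to 10
    x, y = u0[1], u0[2]
    total = 0.0
    for _ in 1:nsteps
        dx = p[1] * x - p[2] * x * y
        dy = -p[3] * y + p[4] * x * y
        x += dt * dx
        y += dt * dy
        total += x + y            # sum of the states after each step
    end
    return total
end

p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0, 1.0]
du = zeros(length(u0))            # shadows start at zero
dp = zeros(length(p))
Enzyme.autodiff(Reverse, euler_loss, Active, Duplicated(u0, du), Duplicated(p, dp))
println(du, " ", dp)              # gradients of the toy loss w.r.t. u0 and p
```

Because this loss uses a different integrator and a different sum than loss above, its gradients are not expected to match the Zygote numbers quoted earlier.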