rkurchin closed this 2 months ago.
@rkurchin it looks like you're on a really old version of Enzyme, what happens if you use the latest?
Yup, I haven't set up juliaup on my laptop at home yet, so it's still on Julia 1.9 (with the latest package versions it would let me install). Now that I'm back at work, here's the output from the same code on Julia 1.10 (the top-line error is ERROR: AssertionError: Enzyme Internal Error: did not have sret when expected): ODE_Enzyme_Julia110.txt
And the environment info:
(@v1.10) pkg> st
Status `~/.julia/environments/v1.10/Project.toml`
[7da242da] Enzyme v0.12.36
[5fb14364] OhMyREPL v0.5.28
[1dea7af3] OrdinaryDiffEq v6.89.0
[14b8a8f1] PkgTemplates v0.7.52
[295af30f] Revise v3.5.18
[1ed8b502] SciMLSensitivity v7.67.0
[1e6cf692] TestEnv v1.102.0
Okay, this should be solved by https://github.com/EnzymeAD/Enzyme.jl/pull/1829
Note that for this to run end to end, you'll probably want to get rid of the type instability as follows:
using SciMLSensitivity, OrdinaryDiffEq, Enzyme
# In-place Lotka-Volterra right-hand side
function fiip(du, u, p, t)
du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2]
du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2]
return nothing
end
p = [1.5, 1.0, 3.0, 1.0];
u0 = [1.0; 1.0];
prob = ODEProblem(fiip, u0, (0.0, 10.0), p)
sol = solve(prob, Tsit5())
loss(u0, p, prob) = sum(solve(prob, Tsit5(), u0 = u0, p = p, saveat = 0.1))
# Shadow arrays: in reverse mode Enzyme adds the gradients into these,
# so they should start at zero (no manual seeding is needed for an Active return)
du = zeros(size(u0))
dp = zeros(size(p))
Enzyme.autodiff(Reverse, loss, Active, Duplicated(u0, du), Duplicated(p, dp), Const(prob))
Also, this seemingly gives a nice speedup, which is cool:
julia> @btime Zygote.gradient((u0,p)->loss(u0,p,prob), u0, p)
8.795 ms (51250 allocations: 2.71 MiB)
([-39.49430913141475, -8.63188820376318], [7.349039360098285, -159.3107987785755, 74.93924800431942, -339.3272380777995])
julia> @btime Enzyme.autodiff(Reverse, loss, Active, Duplicated(u0, du), Duplicated(p, dp), Const(prob))
456.900 μs (2834 allocations: 236.31 KiB)
((nothing, nothing, nothing),)
cc @ChrisRackauckas since integration speedups are always nice
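The ((nothing, nothing, nothing),) returned by Enzyme.autodiff above is expected: for Duplicated arguments, reverse-mode Enzyme does not return the gradient but adds it into the shadow arrays (du and dp). A side effect is that repeated calls, such as the ones @btime makes, keep accumulating into the same shadows. The sketch below shows this on a hypothetical toy loss f(x) = sum(abs2, x) (chosen only because its gradient, 2x, is easy to check by hand), not on the ODE loss itself.

```julia
using Enzyme

# Hypothetical toy loss standing in for the ODE loss; its gradient is 2x.
f(x) = sum(abs2, x)

x  = [1.0, 2.0]
dx = zeros(length(x))  # shadow starts at zero so it holds a pure gradient

# Returns ((nothing,),); the gradient is written into dx instead.
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
g1 = copy(dx)          # [2.0, 4.0]

# A second call (as happens under @btime) adds into the same shadow.
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
g2 = copy(dx)          # [4.0, 8.0]

fill!(dx, 0.0)         # reset the shadow before reading a fresh gradient
```

So to compare against the Zygote gradient, zero du and dp (for example with fill!) right before a single autodiff call and read them afterwards.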
Yeah, what was the issue there? I don't quite know what that PR was doing.
Note that this is hitting the adjoint method on the ODE via the overload, so it's not quite "AD through the ODE solution". That's https://github.com/SciML/OrdinaryDiffEq.jl/pull/2282, which still needs https://github.com/SciML/DiffEqBase.jl/pull/1073 but is basically done. I may just finish that today.
Reposting from Slack at @wsmoses' request...
I’m trying to differentiate through an ODE solution with Enzyme and running into some errors. Basically, I started with the example on this docs page and just tried to Enzyme-ify it like so:
When I initially posted about this on Slack, I was on my desktop in Julia 1.10 and got
ERROR: AssertionError: Enzyme Internal Error: did not have sret when expected f=; Function Attrs: alwaysinline mustprogress willreturn
with a preposterously long stacktrace that I can pull and attach here when I get back to my desk at work tomorrow. Interestingly, when I tried to reproduce this on my laptop at home on Julia 1.9, something different happened (attached because it was still too long to paste in): ODE_Enzyme_Julia19.txt
Here's the full environment...
As I said, I'll grab the output from running this on 1.10 and the associated environment details and add those here tomorrow. LMK if any other info would be helpful in the meantime!