xzackli / Bolt.jl

Differentiable Boltzmann code
MIT License

Enzyme isn't ready for use with Bolt #59

Open xzackli opened 2 years ago

xzackli commented 2 years ago

I played with Enzyme a little bit, and I suspect it's not ready for use with our package. It can't differentiate even simple ODE solves at present. There's a fair amount of linear algebra in OrdinaryDiffEq (for example, the LU factorizations inside implicit solvers like Rodas4), so it's probably tripping up on BLAS.

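```julia
# this will crash
using OrdinaryDiffEq
using Enzyme

f(u,p,t) = 1.01*u   # scalar linear ODE: du/dt = 1.01u

function test(u0)
    tspan = (0.0,1.0)
    prob = ODEProblem(f,u0,tspan)
    # Rodas4 is a stiff Rosenbrock solver: every step solves a linear system (LU via BLAS/LAPACK)
    sol = solve(prob,Rodas4(),reltol=1e-8,abstol=1e-8)
    return sol(1.0)   # interpolated solution at t = 1
end

autodiff(test, Active(1.0)) # (g, 1.0)
```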
xzackli commented 2 years ago

https://gist.github.com/xzackli/7c8819f3e7b43f16481e5d909c0b7764

xzackli commented 2 years ago

However, maybe iterative methods like #57 will fare better? Even so, our background and RECFAST calculations are basically like the example above -- we need gradients through ODE solves.

jmsull commented 2 years ago

I'm surprised such a simple example fails, given that Enzyme has been applied to ODEs before. I guess they weren't using Rodas4, so they weren't hitting BLAS (or whatever the issue is)? Apparently BLAS accounts for the edge cases "99% of the time" here with Enzyme. I haven't tried to read all the details, but we can leave Enzyme aside for now.

Even with the iterative methods, we still need to solve the ODE part with DE solvers, so this problem is not going away.

marius311 commented 2 years ago

Have you guys messed around with https://github.com/JuliaDiff/Diffractor.jl yet? I think it will also eventually have well-optimized scalar forward and reverse mode (scalar meaning it should work well through loops and scalar indexing, unlike e.g. Zygote).

(sorry to interject, I have major FOMO seeing you both do cool stuff here :grin:)

xzackli commented 2 years ago

@jmsull that's a really interesting thread! I'm glad to learn there are CS people at JuliaLab working on Enzyme BLAS support. Yeah, let's just leave this aside for a bit.

@marius311 I'm going to wait until their first tagged release, but it's exciting stuff! My understanding is that Diffractor needs compiler improvements from Julia 1.8, which makes it a bit harder to play with. At the very least, it will be nice to use a forward-mode AD package that supports ChainRules.

marius311 commented 2 years ago

Yeah, I think you'd want 1.8 (so Julia#master atm), but I figured I'd mention it since Enzyme is still pretty early too. FWIW, I have played with it for basic stuff and it is definitely working, but it's certainly not ready to depend on yet. Out of curiosity, is the ForwardDiff stuff already used here not good enough in some ways?

jmsull commented 2 years ago

@marius311 Interested to try it out then! We tried Enzyme even though it is early because the developers (and others) recommended it to us at the AD workshop, but we're happy to try Diffractor as well if you say it's working.

The paper I linked above concludes that

"Our results show a strong performance advantage for automatic differentiation based discrete sensitivity analysis for forward-mode sensitivity analysis on sufficiently small systems, and an advantage for continuous adjoint sensitivity analysis for sufficiently large systems." (cf. Fig. 2)

So we (or at least I) thought switching to reverse mode might give performance gains, since this is a large ODE system (at least with high ell_max).

xzackli commented 2 years ago

Just to add on to Jamie's comment: my understanding is that for n ODEs and p parameters, forward-mode AD scales like O(np), whereas adjoint methods scale like O(n+p) but with a large overhead. arXiv:1812.01892 finds that for small problems (roughly n + p < 50-100) the overhead of adjoint methods isn't worth it and forward-mode AD still wins (it also shows Enzyme having some remarkable performance characteristics). Since the hierarchy tends to require fairly large systems to accurately solve for the transfer functions, the O(np) scaling of forward mode is a problem for us. Also, we have some ambition to involve ML, so p can become large.
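To make the scaling concrete, here is a sketch of the standard continuous forward-sensitivity and adjoint equations. The notation is assumed here rather than taken from the thread: state $u \in \mathbb{R}^n$ with $\dot u = f(u,\theta,t)$, parameters $\theta \in \mathbb{R}^p$, and a scalar loss $G(\theta) = \int_{t_0}^{T} g(u,\theta,t)\,dt$. Counting equations is only a rough proxy for cost; it ignores stiffness, checkpointing, and how the Jacobian products are evaluated.

```math
\begin{aligned}
&\text{Forward sensitivities } (s_j = \partial u/\partial\theta_j):\\
&\dot u = f(u,\theta,t), \qquad
\dot s_j = \frac{\partial f}{\partial u}\,s_j + \frac{\partial f}{\partial \theta_j},
\quad s_j(t_0) = \frac{\partial u_0}{\partial \theta_j}, \quad j = 1,\dots,p
\quad\Rightarrow\quad n(p+1) \text{ coupled equations} = O(np)\\[6pt]
&\text{Adjoint (solved backward from } \lambda(T) = 0\text{):}\\
&\dot\lambda = -\Big(\frac{\partial f}{\partial u}\Big)^{\!\top}\lambda - \Big(\frac{\partial g}{\partial u}\Big)^{\!\top},
\qquad
\frac{dG}{d\theta} = \lambda(t_0)^{\top}\frac{\partial u_0}{\partial \theta}
+ \int_{t_0}^{T}\Big(\frac{\partial g}{\partial \theta} + \lambda^{\top}\frac{\partial f}{\partial \theta}\Big)\,dt\\
&\quad\Rightarrow\quad n \text{ (forward)} + n \text{ (adjoint)} + p \text{ (quadrature)} = O(n+p)
\end{aligned}
```

Each backward step needs the vector-Jacobian products $\lambda^\top \partial f/\partial u$ and $\lambda^\top \partial f/\partial \theta$; in SciMLSensitivity, that is the piece selected by the `autojacvec` option (e.g. `EnzymeVJP()`).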

This does have some nice implications for the AD properties of the hierarchy-less method, since it reduces the system size so much.

xzackli commented 2 years ago

This discussion makes me think that perhaps people were suggesting we use Enzyme only for the ODE-internal VJPs, i.e. as in the sensitivity docs, and use something like Zygote on the outside?

xzackli commented 2 years ago

Here's a somewhat realistic demo. Consider a future hierarchy-less situation where we have a small, stiff ODE system (+ some iterative methods) and some number of parameters we want to sample or optimize over. The rober example from the SciML docs isn't a bad placeholder.

```julia
# DiffEqSensitivity was later renamed SciMLSensitivity
using DiffEqSensitivity, OrdinaryDiffEq, ForwardDiff, Zygote, BenchmarkTools

# Robertson stiff kinetics problem, modified: the extra sum(p) term makes every parameter affect the solution
function rober(du,u,p,t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p[1], p[2], p[3]
    du[1] = -k₁*y₁+k₃*y₂*y₃
    du[2] =  k₁*y₁-k₂*y₂^2-k₃*y₂*y₃
    du[3] =  k₂*y₂^2 + sum(p)
    nothing
end

function run_benchmarks()

    # forward mode: ForwardDiff differentiates straight through the solver
    function sum_of_solution_fwd(x)
        _prob = ODEProblem(rober,x[1:3],(0.0,1e5),x[4:end])
        sum(solve(_prob,Rodas5(),reltol=1e-6,abstol=1e-6))
    end

    # reverse mode: Zygote on the outside, continuous adjoint (CASA) for the solve
    function sum_of_solution_CASA(x)
        sensealg = QuadratureAdjoint()  # change me, lots of choices here (arXiv:1812.01892)
        _prob = ODEProblem(rober,x[1:3],(0.0,1e5),x[4:end])
        sum(solve(_prob,Rodas5(),reltol=1e-6,abstol=1e-6,sensealg=sensealg))
    end

    u0 = [1.0,0.0,0.0]
    p = rand(256)  # change me, the number of parameters

    # differentiate w.r.t. initial conditions and parameters together
    @btime ForwardDiff.gradient($sum_of_solution_fwd,[$u0;$p])
    @btime Zygote.gradient($sum_of_solution_CASA,[$u0; $p])

    nothing
end

run_benchmarks()
#   38.065 ms (18168 allocations: 4.53 MiB)    <- ForwardDiff
#   18.599 ms (117672 allocations: 21.47 MiB)  <- Zygote + QuadratureAdjoint
```

Note that the title of this issue still appears to be correct: using Enzyme for the VJP, as below, is still broken.

```julia
sensealg = QuadratureAdjoint(autojacvec=EnzymeVJP())
```
ChrisRackauckas commented 2 years ago

Found this issue because of your talk. Here's the workaround for right now:

```julia
using SciMLSensitivity, OrdinaryDiffEq, ForwardDiff, Zygote, BenchmarkTools

function rober(du,u,p,t)
    y₁,y₂,y₃ = u
    k₁,k₂,k₃ = p[1], p[2], p[3]
    du[1] = -k₁*y₁+k₃*y₂*y₃
    du[2] =  k₁*y₁-k₂*y₂^2-k₃*y₂*y₃
    du[3] =  k₂*y₂^2 + sum(p)
    nothing
end

function run_benchmarks()

    function sum_of_solution_fwd(x)
        _prob = ODEProblem(rober,x[1:3],(0.0,1e5),x[4:end])
        sum(solve(_prob,Rodas5(),reltol=1e-6,abstol=1e-6))
    end

    function sum_of_solution_CASA(x)
        # Enzyme computes the vector-Jacobian products inside the adjoint solve
        sensealg = QuadratureAdjoint(autojacvec=EnzymeVJP())
        _prob = ODEProblem(rober,x[1:3],(0.0,1e5),x[4:end])
        # autodiff=false: the stiff solver builds its Jacobian with finite differences
        # instead of ForwardDiff, so no ForwardDiff Jacobian is mixed into the reverse pass
        sum(solve(_prob,Rodas5(autodiff=false),reltol=1e-6,abstol=1e-6,sensealg=sensealg))
    end

    u0 = [1.0,0.0,0.0]
    p = rand(256)  # change me, the number of parameters

    @btime ForwardDiff.gradient($sum_of_solution_fwd,[$u0;$p])
    @btime Zygote.gradient($sum_of_solution_CASA,[$u0; $p])

    nothing
end

run_benchmarks()

# 11.490 ms (25068 allocations: 5.52 MiB)   <- ForwardDiff
# 2.956 ms (11024 allocations: 9.73 MiB)    <- Zygote + QuadratureAdjoint(EnzymeVJP)
```

The issue is mixing the forward-mode Jacobian for the nonlinear solver with reverse mode. I think this isn't too hard to fix; I'll make it into a test case.

xzackli commented 2 years ago

Thanks @ChrisRackauckas, this is exciting stuff -- I think there will be a substantial (order of magnitude?) performance improvement for us. We'll play with the workaround for now and watch for those changes to SciMLSensitivity.

I'd love for this example to make it into the tests. It's basically a cartoon for the problem we're trying to solve: stiff evolution being compared to data with a likelihood evaluation.

ChrisRackauckas commented 2 years ago

Fixed in SciMLSensitivity v7.2.0, and this is now a test. Let me know if you run into anything else. For reference, the test set for mixing stiff solvers with adjoints is https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/stiff_adjoints.jl; it just had a blind spot there.