SciML / JumpProcesses.jl

Build and simulate jump equations like Gillespie simulations and jump diffusions with constant and state-dependent rates and mix with differential equations and scientific machine learning (SciML)
https://docs.sciml.ai/JumpProcesses/stable/

Efficiency of ExtendedJumpArray broadcasting in ode_interpolant #335

Closed: meson800 closed this issue 10 months ago

meson800 commented 11 months ago

Background

Continuing to optimize a system with VariableRateJumps and callbacks, I've found that performing ODE interpolations on ExtendedJumpArrays consumes ~80% of the total runtime of the solve. About 10% is in the various do_step backtraces, with most of the runtime in the find_first_continuous_callback call. Of that, the vast majority is spent in Base.getindex(A::ExtendedJumpArray).

Problem

In particular, it looks like the combination of ExtendedJumpArrays, @muladd, and FastBroadcast is causing suboptimal code generation. For my use case, the problematic function is here:

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
    cache::Union{Tsit5ConstantCache, Tsit5Cache},
    idxs::Nothing, T::Type{Val{0}})
    @tsit5pre0
    @inbounds @.. broadcast=false out=y₀ +
                                      dt *
                                      (k[1] * b1Θ + k[2] * b2Θ + k[3] * b3Θ + k[4] * b4Θ +
                                       k[5] * b5Θ + k[6] * b6Θ + k[7] * b7Θ)
    out
end

but the problem seems to affect other interpolants as well.

Examining this with Cthulhu and friends (looking at the LLVM and native code, etc.) shows an absolute morass of branching, with something like 4800 (!) branches in the assembly. The branches come from inlining the ExtendedJumpArray index check. As a result, the compiler doesn't seem to be able to unroll the loop, since it has to check "is index <= length(jump_array.u)" over and over again.
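To see where those branches come from, here is a minimal sketch of a two-part vector that indexes the way described above: an `i <= length(u)` comparison decides which buffer to read. `SplitVector`, `sum_indexed`, and `sum_split` are hypothetical names for illustration, not the JumpProcesses.jl implementation.

```julia
# Hypothetical stand-in for ExtendedJumpArray (illustration only):
# two storage vectors presented as one linear AbstractVector.
struct SplitVector{T} <: AbstractVector{T}
    u::Vector{T}
    jump_u::Vector{T}
end

Base.size(A::SplitVector) = (length(A.u) + length(A.jump_u),)

# Every scalar access first decides which storage vector holds index i.
# Inside a broadcast kernel this comparison runs once per element, per argument.
Base.@propagate_inbounds function Base.getindex(A::SplitVector, i::Int)
    nu = length(A.u)
    return i <= nu ? A.u[i] : A.jump_u[i - nu]
end

# Linear-indexing loop: one branch per iteration.
function sum_indexed(A::SplitVector)
    s = zero(eltype(A))
    for i in 1:length(A)
        s += A[i]
    end
    return s
end

# Split version: two branch-free passes over contiguous storage.
sum_split(A::SplitVector) = sum(A.u) + sum(A.jump_u)
```

Both functions return the same value. `sum_indexed` evaluates the comparison for every element, which mirrors the pattern reported in the interpolant kernel, while `sum_split` works directly on each contiguous buffer.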

Solution

I'm guessing this would be a lot faster, and would stop churning the branch predictor with a bunch of mostly useless (<=) comparisons to switch between the .u and .jump_u members, if we could somehow turn these calls into separate calls like

out.u .= y₀.u .+ dt .* (...)
out.jump_u .= y₀.jump_u .+ dt .* (...)

I fixed this by adding the following dispatch to my user code (I expanded one of the macros by hand because getting the imports right was annoying; otherwise it's just the original _ode_interpolant! with the .u and .jump_u parts unrolled):

using JumpProcesses: ExtendedJumpArray
using MuladdMacro, FastBroadcast
import OrdinaryDiffEq
# Names referenced by the macro expansions below
using OrdinaryDiffEq: constvalue, Tsit5Interp

# Tsit5 order-0 dense output specialized for ExtendedJumpArray: the same kernel as the
# generic method, but broadcast separately over the `u` and `jump_u` storage vectors.
function OrdinaryDiffEq._ode_interpolant!(out::ExtendedJumpArray, Θ, dt,
    y₀::ExtendedJumpArray, y₁, k,
    cache::Union{OrdinaryDiffEq.Tsit5ConstantCache, OrdinaryDiffEq.Tsit5Cache},
    idxs::Nothing, T::Type{Val{0}})
    OrdinaryDiffEq.@tsit5unpack
    # Hand-expanded @tsit5pre0: interpolation weights b1Θ ... b7Θ
    Θ² = Θ * Θ
    b1Θ = Θ * @evalpoly(Θ, r11, r12, r13, r14)
    b2Θ = Θ² * @evalpoly(Θ, r22, r23, r24)
    b3Θ = Θ² * @evalpoly(Θ, r32, r33, r34)
    b4Θ = Θ² * @evalpoly(Θ, r42, r43, r44)
    b5Θ = Θ² * @evalpoly(Θ, r52, r53, r54)
    b6Θ = Θ² * @evalpoly(Θ, r62, r63, r64)
    b7Θ = Θ² * @evalpoly(Θ, r72, r73, r74)
    # ODE state part
    @muladd @inbounds FastBroadcast.@.. broadcast=false out.u = y₀.u +
        dt * (k[1].u * b1Θ + k[2].u * b2Θ + k[3].u * b3Θ + k[4].u * b4Θ +
              k[5].u * b5Θ + k[6].u * b6Θ + k[7].u * b7Θ)
    # Jump-state part
    @muladd @inbounds FastBroadcast.@.. broadcast=false out.jump_u = y₀.jump_u +
        dt * (k[1].jump_u * b1Θ + k[2].jump_u * b2Θ + k[3].jump_u * b3Θ + k[4].jump_u * b4Θ +
              k[5].jump_u * b5Θ + k[6].jump_u * b6Θ + k[7].jump_u * b7Θ)
    out
end

This dramatically reduces the runtime, and inspecting the LLVM code shows "only" 123 branches with good loop unrolling. There's still a boatload of allocations happening in handle_callbacks, but I'll deal with that in a separate issue/PR.

Questions

isaacsas commented 11 months ago

I don't think we'd want to add ODE method dependent dispatches. That would be a nightmare to put together and maintain. Dispatches that are not ODE-method dependent would be fine (i.e., things that are used by all or many methods, like your last PR; so more dispatches on things in DiffEqBase, SciMLBase, or Julia's Base). But maybe there is an ExtendedJumpArray broadcast feature that just isn't implemented and, if dispatched, would fix the issue here? (Unfortunately I have no familiarity with the broadcast API, so I can't help with suggestions. Maybe @ChrisRackauckas has some.)

@.. comes from FastBroadcast.jl if you want to look at what it does:

https://github.com/YingboMa/FastBroadcast.jl

ChrisRackauckas commented 11 months ago

https://github.com/YingboMa/FastBroadcast.jl/blob/master/src/FastBroadcast.jl#L19

Try overriding this to be false so it doesn't use the linear indexing?

> I don't think we'd want to add ODE method dependent dispatches. That would be a nightmare to put together and maintain.

Yes, it should just be one fix in how broadcasting is lowered.

meson800 commented 11 months ago

> Try overriding this to be false so it doesn't use the linear indexing?

Overriding that to false doesn't seem to solve the problem; there are still a bunch of calls to getindex, specifically the x <= length(eja.u) check.

I've been thinking through the broadcast interface, and it feels like what we want is a broadcast style that flattens the broadcast kernel and then applies it to both u and jump_u. It seems that, at least back in 2018, someone was trying to do something similar using the broadcast rules: https://github.com/JuliaLang/julia/issues/27988#issuecomment-403319535
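The "flatten the kernel, then apply it to each part" idea can be sketched with Julia's documented broadcast customization hooks (`BroadcastStyle`, `similar`, `copyto!`). This is a minimal sketch under stated assumptions: `SplitVector`, `find_split`, `unpack`, `part_u`, and `part_jump` are hypothetical names, not JumpProcesses.jl code, and FastBroadcast's `@..` is not covered.

```julia
using Base.Broadcast: Broadcasted, ArrayStyle

# Hypothetical stand-in for ExtendedJumpArray (illustration only).
struct SplitVector{T} <: AbstractVector{T}
    u::Vector{T}
    jump_u::Vector{T}
end
Base.size(A::SplitVector) = (length(A.u) + length(A.jump_u),)
Base.@propagate_inbounds Base.getindex(A::SplitVector, i::Int) =
    i <= length(A.u) ? A.u[i] : A.jump_u[i - length(A.u)]

# Opt in to a custom style so dot-calls stop falling back to DefaultArrayStyle.
Base.BroadcastStyle(::Type{<:SplitVector}) = ArrayStyle{SplitVector}()

# Locate the first SplitVector among (possibly nested) broadcast arguments,
# following the ArrayAndChar example in the Julia manual.
find_split(bc::Broadcasted) = find_split(bc.args)
find_split(args::Tuple) = find_split(find_split(args[1]), Base.tail(args))
find_split(x) = x
find_split(::Tuple{}) = nothing
find_split(A::SplitVector, rest) = A
find_split(::Any, rest) = find_split(rest)

# Out-of-place broadcasts (a .+ b) allocate a SplitVector with the same split point.
function Base.similar(bc::Broadcasted{ArrayStyle{SplitVector}}, ::Type{T}) where {T}
    A = find_split(bc)
    return SplitVector(similar(A.u, T), similar(A.jump_u, T))
end

# Rebuild the broadcast tree with every SplitVector replaced by one of its parts;
# scalars and other arguments pass through unchanged.
unpack(bc::Broadcasted, part) = Broadcasted(bc.f, map(x -> unpack(x, part), bc.args))
unpack(A::SplitVector, part) = part(A)
unpack(x, part) = x

part_u(A::SplitVector) = A.u
part_jump(A::SplitVector) = A.jump_u

# Run the fused kernel once per storage vector, so each inner loop
# indexes a plain Vector with no per-element branch.
function Base.copyto!(dest::SplitVector, bc::Broadcasted{ArrayStyle{SplitVector}})
    Base.Broadcast.materialize!(dest.u, unpack(bc, part_u))
    Base.Broadcast.materialize!(dest.jump_u, unpack(bc, part_jump))
    return dest
end
```

With these methods, `a .+ 2.0 .* b` returns a `SplitVector`, and `out .= a .* b` writes into both parts in place. Because each part is broadcast separately, inputs whose split points disagree now raise a `DimensionMismatch` instead of being silently mixed.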

I'll update if I make any progress. Frustratingly, my debugging so far shows that broadcast operations don't use the defined ExtendedJumpArrayStyle at all. For example, this code:

using JumpProcesses
k1 = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(5), rand(2))
k2 = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(5), rand(2))
@code_warntype k1 .+ k2

shows:

Arguments
  #self#::Core.Const(var"##dotfunction#1496#37"())
  x1::ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}
  x2::ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}
Body::Vector{Float64}
1 ─ %1 = Base.broadcasted(Main.:+, x1, x2)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}, ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}}}
│   %2 = Base.materialize(%1)::Vector{Float64}
└──      return %2

which is actually just wrong; this should not just return a Vector{Float64}. I don't know why it's doing this, but I checked that this happens even with a clean Julia environment (only [ccbc3e58] JumpProcesses v9.7.2 when running Pkg.status)

By contrast, the non-broadcast k1 + k2 does return an ExtendedJumpArray.

ChrisRackauckas commented 11 months ago

> which is actually just wrong; this should not just return a Vector{Float64}. I don't know why it's doing this, but I checked that this happens even with a clean Julia environment (only [ccbc3e58] JumpProcesses v9.7.2 when running Pkg.status)

This means it's hitting the fallback broadcast style; with the generic AbstractArray style, broadcasting defaults to returning an Array. That path goes through scalar indexing and would be type unstable if u and jump_u don't have the same element type.
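The fallback is easy to reproduce without JumpProcesses. The sketch below uses a hypothetical `PlainSplit` type (not the package's code) that defines `size` and `getindex` but no broadcast style. Broadcasting two of them yields a `Vector{Float64}`, and because only the total length is checked, inputs with different `u`/`jump_u` split points are combined without complaint.

```julia
# Hypothetical two-part vector that defines NO BroadcastStyle method.
struct PlainSplit{T} <: AbstractVector{T}
    u::Vector{T}
    jump_u::Vector{T}
end
Base.size(A::PlainSplit) = (length(A.u) + length(A.jump_u),)
Base.getindex(A::PlainSplit, i::Int) =
    i <= length(A.u) ? A.u[i] : A.jump_u[i - length(A.u)]

a = PlainSplit([1.0, 2.0], [3.0])
b = PlainSplit([1.0], [2.0, 3.0])  # different split point, same total length

# Any AbstractVector without its own style gets DefaultArrayStyle{1} ...
style = Base.Broadcast.combine_styles(a, b)
# ... which allocates a plain Array and fills it element by element via getindex.
c = a .+ b
```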

The thing to look at would probably be https://github.com/jonniedie/ComponentArrays.jl. It solves this kind of problem quite nicely, so we may just want to lift some of its broadcast implementation. I've also considered completely removing the ExtendedJumpArray and just using a ComponentArray.

The issue, though, is that it's somewhat magical to the user if they solve using an Array of size 5 and then get a solution that's an array of size 7. The reason this machinery is a bit different is to hide from the user the fact that it is constructing and solving a larger system. Potentially, the solution is to just use a ComponentArray but make the saving code allow for saving only the "Array" part.

meson800 commented 11 months ago

Thanks for the reference! It looks like ComponentArrays actually has very little broadcasting magic but much more advanced indexing magic. I'll see if there is a straightforward way to get the ExtendedJumpArray to actually use the broadcasting code that has already been written; otherwise I'll see if there is some indexing magic to use.

ChrisRackauckas commented 11 months ago

I think the fact that ComponentArrays forces everything to have the same element type (so you cannot end up in a situation where u is Int and jump_u is Float64) is doing a lot of heavy lifting there. I am unsure whether we should just promote the type.
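If promotion were the route taken, the mechanics could look like the sketch below: a hypothetical `PromotedSplit` type (not JumpProcesses.jl code) whose outer constructor uses `promote_type` to give `u` and `jump_u` one common element type, so scalar `getindex` always infers a single concrete type. The sketch only shows how it would work, not whether the package should do it.

```julia
# Sketch, assuming we opt to promote: both parts share one element type T,
# so u::Vector{Int} paired with jump_u::Vector{Float64} cannot occur.
struct PromotedSplit{T} <: AbstractVector{T}
    u::Vector{T}
    jump_u::Vector{T}
end

# Outer constructor: promote mismatched element types to a common one.
function PromotedSplit(u::AbstractVector, jump_u::AbstractVector)
    T = promote_type(eltype(u), eltype(jump_u))
    return PromotedSplit{T}(convert(Vector{T}, u), convert(Vector{T}, jump_u))
end

Base.size(A::PromotedSplit) = (length(A.u) + length(A.jump_u),)
# With a single T, both branches return T, so getindex is type stable.
Base.getindex(A::PromotedSplit, i::Int) =
    i <= length(A.u) ? A.u[i] : A.jump_u[i - length(A.u)]
```

For example, `PromotedSplit([1, 2, 3], [0.5])` stores both parts as `Vector{Float64}`. The trade-off is that a user's `Int` state would silently come back as `Float64`.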

meson800 commented 11 months ago

Yeah, I did test getting broadcasting working and ran into something similar. The current broadcast code wasn't actually doing anything, but once I added two more interface functions in Broadcast, broadcasting "works", in that dot operations properly fuse, do what you expect, and return an ExtendedJumpArray. The limitations were:

I might first try the hacky approach of directly calling fast_materialize! inside the ExtendedJumpArray broadcast code, just to see whether there are speed improvements. This would definitely not be a PR-worthy solution, but it would at least tell me whether things can be improved.

meson800 commented 10 months ago

I haven't figured out a solution to this so far, though I'd like to; even simply adding two ExtendedJumpArrays is about five times slower than adding the equivalent Vector{Float64}s, which feels bad.

Notes so far:

  1. The current broadcasting overloads that define the ExtendedJumpArrayStyle don't do anything, at least in current Julia. This is because they define a method for Broadcast.BroadcastStyle but not for Base.BroadcastStyle; and even when that method is added, broadcasting still falls back to DefaultArrayStyle. Since the style appears unused, I'm considering just replacing it.
  2. For dot-broadcasting, the current implementation of copyto! already does the efficient thing we want: it unpacks the broadcast call and processes the u and jump_u arrays separately. It just doesn't get called, because ExtendedJumpArrayStyle doesn't fully implement the broadcasting interface.
  3. For FastBroadcast as used in ode_interpolant, the @.. macro replaces calls to materialize! with fast_materialize!. Unfortunately, that means any special code we put into copyto! won't get called by FastBroadcast. Technically, we could define overloads for FastBroadcast.fast_materialize and FastBroadcast.fast_materialize! that are effectively identical to copyto! and internally unpack the Broadcasted object.

At the very least, I'm trying to address 1) and 2). There are also correctness issues with the current fallback to DefaultArrayStyle; for example, this happily works despite the u/jump_u length mismatch:

ExtendedJumpArray(rand(100), rand(10)) .+ ExtendedJumpArray(rand(90), rand(20))

For 3), @ChrisRackauckas, do you think it's OK to add a dependency on FastBroadcast to this package? Or a Julia 1.9 package extension that is loaded when FastBroadcast is? Because FastBroadcast.@.. first lowers the broadcast expression and then just replaces materialize! calls, there's no hook in the Broadcast interface we can use to override its behavior besides defining methods for fast_materialize and fast_materialize!.

meson800 commented 10 months ago

I did it, at least for normal broadcasting! I should have a PR in today. I rewrote the broadcast rules and slightly changed how broadcast repacking works.

I'm working on the FastBroadcast overloads right now, confirming that they work as expected.

For the benchmark, I compare against a plain linear vector of the same total length. The old fallback mechanism was ~3-5x slower, but now the two are effectively the same. Checking with Cthulhu shows that efficient SIMD instructions are now being emitted.

using BenchmarkTools, JumpProcesses, Random
rng = Xoshiro(1234)  # seeded RNG for reproducible inputs

bench_out_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(rng, 500000),
                                                                             rand(rng, 500000))
bench_in_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(rng, 500000),
                                                                            rand(rng, 500000))
base_out_array = rand(rng, 500000 * 2)
base_in_array = rand(rng, 500000 * 2)

function test_single_dot(out, array)
    @inbounds @. out = array + 1.0 * array + 1.2 * array
end
test_single_dot(bench_out_array, bench_in_array)  # warm up / compile
# Interpolate globals with $ so BenchmarkTools doesn't time global-variable access
@benchmark test_single_dot($bench_out_array, $bench_in_array)
@benchmark test_single_dot($base_out_array, $base_in_array)

[Image: benchmarks]

isaacsas commented 10 months ago

I’d be fine with adding a dependency on FastBroadcast if that enables you to fix the issue. Please go ahead and add it if needed.