meson800 closed this issue 10 months ago
I don't think we'd want to add ODE method dependent dispatches. That would be a nightmare to put together and maintain. Dispatches that are not ODE-method dependent would be fine (i.e. things that are used by all / many methods like your last PR, so more dispatches on things in DiffEqBase, SciMLBase, or Julia's Base). But maybe there is a ExtendedJumpArray broadcast feature that just isn't implemented and if dispatched would fix the issue here? (I really have no familiarity with the broadcast API unfortunately, so can't help with suggestions. Maybe @ChrisRackauckas has some suggestions.)
`@..` comes from FastBroadcast.jl if you want to look at what it does: https://github.com/YingboMa/FastBroadcast.jl/blob/master/src/FastBroadcast.jl#L19

Try overriding this to be false so it doesn't use the linear indexing?
> I don't think we'd want to add ODE method dependent dispatches. That would be a nightmare to put together and maintain.
Yes, it should just be one thing about broadcast lowering
> Try overriding this to be false so it doesn't use the linear indexing?

Overriding that to false doesn't seem to solve the problem; there's still a bunch of calls to `getindex`, and specifically `x <= length(eja.u)`.
I've been thinking through the broadcast interface; it feels like what we want is a broadcast style that ends up flattening the broadcast kernel, then applies it to both `u` and `jump_u`. It does seem like, at least in 2018, someone was trying to do something similar using the broadcast rules: https://github.com/JuliaLang/julia/issues/27988#issuecomment-403319535

I'll update if I make any progress. Right now my debugging frustratingly seems to show that broadcast operations are not using the defined `ExtendedJumpArrayStyle`, e.g. this code:
```julia
using JumpProcesses
k1 = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(5), rand(2))
k2 = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(5), rand(2))
@code_warntype k1 .+ k2
```
shows:
```
Arguments
  #self#::Core.Const(var"##dotfunction#1496#37"())
  x1::ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}
  x2::ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}
Body::Vector{Float64}
1 ─ %1 = Base.broadcasted(Main.:+, x1, x2)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}, ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}}}
│   %2 = Base.materialize(%1)::Vector{Float64}
└── return %2
```
which is actually just wrong; this should not just return a `Vector{Float64}`. I don't know why it's doing this, but I checked that this happens even with a clean Julia environment (only `[ccbc3e58] JumpProcesses v9.7.2` when running `Pkg.status`). Running `k1 + k2` results in an `ExtendedJumpArray`.
> which is actually just wrong; this should not just return a `Vector{Float64}`. I don't know why it's doing this, but I checked that this happens even with a clean Julia environment (only `[ccbc3e58] JumpProcesses v9.7.2` when running `Pkg.status`)
This means it's hitting the fallback broadcast style, and when it uses the AbstractArray style it defaults to returning an `Array`. That would use indexing and be type unstable if `u` and `jump_u` do not have the same element type.
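To illustrate the fallback being discussed, here is a minimal sketch of the Broadcast interface methods a wrapper type needs so that dotted operations return the wrapper instead of a plain `Vector`. The `DemoPair`/`DemoPairStyle` names are illustrative toys, not JumpProcesses code, and a real implementation would need to search `bc.args` recursively for nested `Broadcasted` arguments:

```julia
# Toy two-part array standing in for ExtendedJumpArray.
struct DemoPair{T} <: AbstractVector{T}
    u::Vector{T}
    jump_u::Vector{T}
end
Base.size(p::DemoPair) = (length(p.u) + length(p.jump_u),)
Base.getindex(p::DemoPair, i::Int) =
    i <= length(p.u) ? p.u[i] : p.jump_u[i - length(p.u)]
Base.setindex!(p::DemoPair, v, i::Int) =
    i <= length(p.u) ? (p.u[i] = v) : (p.jump_u[i - length(p.u)] = v)

# A custom style plus the Val constructor used when mixing with scalars.
struct DemoPairStyle <: Broadcast.AbstractArrayStyle{1} end
DemoPairStyle(::Val{1}) = DemoPairStyle()
Base.BroadcastStyle(::Type{<:DemoPair}) = DemoPairStyle()

# `similar` for the style is what makes `materialize` allocate a DemoPair.
function Base.similar(bc::Broadcast.Broadcasted{DemoPairStyle}, ::Type{T}) where {T}
    # Simplification: only looks at top-level args, not nested Broadcasted.
    p = first(x for x in bc.args if x isa DemoPair)
    DemoPair{T}(similar(p.u, T), similar(p.jump_u, T))
end

a = DemoPair([1.0, 2.0], [3.0])
b = DemoPair([10.0, 20.0], [30.0])
c = a .+ 2.0 .* b
println(typeof(c))  # → DemoPair{Float64}, no longer Vector{Float64}
```

With only `Base.BroadcastStyle` and `similar` defined, the generic `copyto!` still iterates element-by-element through the branchy `getindex`; returning the right type and being fast are separate problems.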
The thing to look at would probably be https://github.com/jonniedie/ComponentArrays.jl. It solves these kinds of problems quite nicely, so we may just want to lift some of its broadcast implementation. I've also considered completely removing the ExtendedJumpArray and just using a ComponentArray.
Though the issue is that it's somewhat magical to the user if they solve using an `Array` of size 5 and then get a solution that's an array of size 7. The reason for this machinery being a bit different is to try to mask from the user the fact that it is constructing and solving a larger system. Potentially, the solution is to just use a ComponentArray but make the saving code allow for saving just the "Array" part.
Thanks for the reference! It looks like ComponentArrays actually has very little broadcasting magic, but a much more advanced indexing magic. I'll see if there is a straightforward way to get the ExtendedJumpArray to actually use the broadcasting code that has already been written, otherwise I'll see if there is some indexing magic to use.
I think the fact that ComponentArrays forces everything to have the same type, so you cannot end up in the situation where `u` is an `Int` and `jump_u` is a `Float64`, is doing a lot of heavy lifting there. I am unsure if we should just promote the type.
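If promotion is the route taken, the target eltype for a mixed `(Float64, Int64)` pair is already computable with Base's promotion machinery; a small sketch of the rules the thread expects:

```julia
# What eltype a fused broadcast over mixed parts "should" produce,
# using Base's standard promotion rules.
println(promote_type(Float64, Int64))           # → Float64
println(Base.promote_op(+, Float64, Int64))     # → Float64
# Multiplying an Int64 part by pi promotes to Float64, matching the
# "(Float64, Int64) * pi gives (Float64, Float64)" expectation above.
println(Base.promote_op(*, Int64, typeof(pi)))  # → Float64
```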
Yeah, I did test out getting broadcasting working and ran into something similar. The current broadcast code wasn't actually doing anything, but once I added two more interface functions in `Broadcast`, broadcasting "works", as in dot operations properly fuse, do what you expect, and return an `ExtendedJumpArray`. The limitations were:

1. `ExtendedJumpArray` code uses the type of the `u` array, so even if you have a `(u, jump_u) = (Float64, Int64)` setup, a single dot broadcast converts it to `(Float64, Float64)`. It might be possible to propagate `(Float64, Int64)` by encoding it into the broadcast style type, but I think this would also require manually writing style mixing rules to implement the expected integer/floating point promotion rules, which would be gross (e.g. if we multiply a `(Float64, Int64)` by `pi`, we really do expect to get a `(Float64, Float64)` out).
2. In the `ode_interpolant` section, I no longer see it take a lot of time in the inefficient `get_index`... I now see it taking a lot of time in the default `materialize!` implementation, which FastBroadcast falls back to. In particular, the `@..` macro can't "see" the `materialize!` calls inside the broadcast machinery, so it can't convert them to `fast_materialize!` calls.

I might try out the hacky thing first of just directly calling `fast_materialize!` inside the `ExtendedJumpArray` broadcast code, just to see if there are speed improvements. This would definitely not be a PR-worthy solution, but would at least let me know if it could be improved.
I haven't figured out a solution to this so far, though I'd like to; even simply adding two `ExtendedJumpArray`s together is about five times slower than adding equivalent `Vector{Float64}`s, which feels bad.
Notes so far:

1. The `ExtendedJumpArrayStyle` definitions don't do anything, at least in current Julia. This is because it defines a method for `Broadcast.BroadcastStyle` but not the `Base.BroadcastStyle` function, but even if you include this, broadcasting falls back to `DefaultArrayStyle`. Since it appears unused, I'm considering just replacing it.
2. `copyto!` does the efficient thing we want: it unpacks the broadcast call and does the `u` and `jump_u` arrays separately. It just doesn't currently get called due to the `ExtendedJumpArrayStyle` not fully implementing the broadcasting interface.
3. In `ode_interpolant`, the `@..` macro is replacing calls to `materialize!` with `fast_materialize!`. Unfortunately, that means that any special code we put into `copyto!` won't get called by FastBroadcast. Technically, we could define overloads for `FastBroadcast.fast_materialize` and `FastBroadcast.fast_materialize!` that are effectively identical to `copyto!` and just internally unpack the `Broadcasted` object.

At the very least, I'm trying to address 1) and 2). There are currently correctness issues with the current fallback to `DefaultArrayStyle`, like this happily working despite the `u`/`jump_u` mismatch:
```julia
ExtendedJumpArray(rand(100), rand(10)) .+ ExtendedJumpArray(rand(90), rand(20))
```
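The "unpack the broadcast and do the parts separately" idea, plus the part-length check that would catch the mismatch above, can be sketched on a toy type. This is not the JumpProcesses implementation (the `Pair2`, `repack`, and `apply_parts!` names are made up for illustration), and it applies a flat function rather than repacking a full `Broadcasted` tree, but it shows the shape of the fix:

```julia
# Toy two-part container standing in for ExtendedJumpArray.
struct Pair2{T}
    u::Vector{T}
    jump_u::Vector{T}
end

# Rebuild the argument tuple, selecting one part from every Pair2
# and passing scalars/plain arrays through unchanged.
repack(args::Tuple, part::Symbol) =
    map(a -> a isa Pair2 ? getfield(a, part) : a, args)

function apply_parts!(dest::Pair2, f, args...)
    # Check each part's length instead of the flattened total length,
    # so a (100,10) vs (90,20) mismatch errors rather than "working".
    for a in args
        if a isa Pair2
            (length(a.u) == length(dest.u) && length(a.jump_u) == length(dest.jump_u)) ||
                throw(DimensionMismatch("u/jump_u part lengths differ"))
        end
    end
    # Two tight, branch-free kernels: one per part.
    broadcast!(f, dest.u, repack(args, :u)...)
    broadcast!(f, dest.jump_u, repack(args, :jump_u)...)
    return dest
end

x = Pair2(rand(100), rand(10))
y = Pair2(rand(90), rand(20))
z = Pair2(zeros(100), zeros(10))
apply_parts!(z, +, x, Pair2(rand(100), rand(10)))  # ok: operates per part
try
    apply_parts!(z, +, x, y)  # mismatched parts now raise an error
catch err
    println(typeof(err))  # → DimensionMismatch
end
```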
For 3), @ChrisRackauckas, do you think it's ok to add a dependency on FastBroadcast to this package? Or a Julia 1.9 extension if FastBroadcast is loaded? Because `FastBroadcast.@..` first lowers broadcasted code and then just replaces `materialize!` calls, there's no hook in the Broadcast interface we can use to override its behavior besides defining methods for `fast_materialize` and `fast_materialize!`.
I did it, at least for normal broadcasting! I should have a PR in today. I rewrote the broadcast rules and slightly changed how broadcast repacking works.
I'm working on the FastBroadcast overloads right now, confirming that they work as expected.
For the benchmark, I compare just against a linear vector. The old fallback mechanism used to be ~3-5x as slow, but now it is effectively the same. Checking with Cthulhu shows that efficient SIMD instructions are being emitted now.
```julia
using BenchmarkTools
using JumpProcesses
using Random

rng = Random.default_rng()  # added: `rng` was not defined in the original snippet

bench_out_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(rng, 500000),
                                                                                  rand(rng, 500000))
bench_in_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(rng, 500000),
                                                                                 rand(rng, 500000))
base_out_array = rand(rng, 500000 * 2)
base_in_array = rand(rng, 500000 * 2)

function test_single_dot(out, array)
    @inbounds @. out = array + 1.0 * array + 1.2 * array
end

test_single_dot(bench_out_array, bench_in_array)
@benchmark test_single_dot(bench_out_array, bench_in_array)
@benchmark test_single_dot(base_out_array, base_in_array)
```
I’d be fine with adding a dependency on FastBroadcast if that enables you to fix the issue. Please go ahead and add it if needed.
### Background

Continuing to optimize a system with VariableRateJumps and callbacks, I've found that performing ODE interpolations on `ExtendedJumpArray`s is consuming around ~80% of the total runtime of the solve. About 10% is actually in the various `do_step` backtraces, with most of the runtime in the `find_first_continuous_callback` call. Of that, the vast majority is happening in `Base.getindex(A::ExtendedJumpArray)`.
### Problem

In particular, it looks like the combination of `ExtendedJumpArray`s, `@muladd`, and `FastBroadcast` is causing suboptimal code generation. For my use case, the problematic function is here, but it seems to affect others as well.
Examining this with Cthulhu and friends, looking at the LLVM and native code, etc. shows an absolute morass of branching, with something like 4800 (!) branches in the assembly. The branches come from inlining the `ExtendedJumpArray` index check. This churn means that the compiler doesn't seem to be able to loop unroll, since it has to repeatedly check "is index < length(jump_array.u)", over and over and over.
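The per-element branch being described can be reproduced on a toy type. This sketch (illustrative names, not the JumpProcesses code) contrasts flattened linear indexing, which carries one `<=` test per element inside the hot loop, with per-part access, which gives the compiler two branch-free loops it can unroll and vectorize:

```julia
# Toy two-part array with the same flattened-indexing scheme.
struct Parted
    u::Vector{Float64}
    jump_u::Vector{Float64}
end
total_len(p::Parted) = length(p.u) + length(p.jump_u)

# Flattened access: the `<=` check the assembly is full of.
getidx(p::Parted, i::Int) =
    i <= length(p.u) ? p.u[i] : p.jump_u[i - length(p.u)]

function sum_flat(p::Parted)
    s = 0.0
    for i in 1:total_len(p)
        s += getidx(p, i)  # branch inside the hot loop, per element
    end
    return s
end

# Per-part access: two tight loops with no part-selection branch.
sum_parts(p::Parted) = sum(p.u) + sum(p.jump_u)

p = Parted(rand(5), rand(2))
println(sum_flat(p) ≈ sum_parts(p))  # → true; same result, different codegen
```

Both compute the same value; the difference only shows up in the generated code, which is exactly what the Cthulhu inspection above was measuring.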
### Solution

I'm guessing that this would be a lot faster and not churn the branch predictor with a bunch of mostly-useless `<=` calls to switch between the `.u` and `.jump_u` members if we could somehow turn these calls into separate per-part calls. I fixed this by adding a dispatch to my user code (I unpacked one of the macros because getting the imports right was annoying, but it's just the original `ode_interpolant` with the `.u` and `.jump_u` parts unrolled). This dramatically reduces the runtime, and inspecting the LLVM code shows "only" 123 branches with good loop unrolling. There's still a boatload of allocations happening in `handle_callbacks`, but I'll deal with that in a separate issue/PR.

### Questions
Is there a way to get the `@..` macro to do this type of code generation automatically? Or should a fix like this live in `OrdinaryDiffEq`?