willtebbutt opened this issue 1 month ago
Tapir.jl does not perform as well as it could on functions like the following:
```julia
function foo!(y::Vector{Float64}, x::Vector{Float64})
    @inbounds @simd for n in eachindex(x)
        y[n] = y[n] + x[n]
    end
    return y
end
```
For example, on my computer:
```julia
y = randn(4096)
x = randn(4096)

julia> @benchmark foo!($y, $x)
BenchmarkTools.Trial: 10000 samples with 173 evaluations.
 Range (min … max):  547.150 ns … 3.138 μs   ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     646.633 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   682.488 ns ± 116.548 ns ┊ GC (mean ± σ):  0.00% ± 0.00%

                ▄██▂
  ▁▁▂▄▇████▇▇▇▆▆▅▅▅▅▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  547 ns          Histogram: frequency by time         1.18 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
```
```julia
rule = Tapir.build_rrule(foo!, y, x);
foo!_d = zero_fcodual(foo!)
y_d = zero_fcodual(y)
x_d = zero_fcodual(x)
out, pb!! = rule(foo!_d, y_d, x_d);
```
```julia
julia> @benchmark ($rule)($foo!_d, $y_d, $x_d)[2](NoRData())
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  64.042 μs … 202.237 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     78.675 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   75.763 μs ±  10.175 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▇ ▇ ▇ ▄▂ ▅ ▃ ▆ ▅▆ █▄ ▄ ▁▂ ▂ ▁ ▂ ▃
  █▃█▃█▄██▁▄▁▆▄▁▁█▄█▆██████████▇▄██▃▅█▃▃█▆▆▇█▆▆▆▇▆▄▁█▃▃▃▅█▅▃▁█ █
  64 μs         Histogram: log(frequency) by time       108 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
```
So the performance ratio is roughly 64 / 0.5 = 128.
Note that this is not due to type-instabilities. One way to convince yourself of this is that there are no allocations required to run AD, which would most certainly not be the case were there type instabilities. Rather, the problems are to do with the overhead associated with our implementation of reverse-mode AD.
To see this, take a look at the optimised IR for `foo!`:
```
2 1 ── %1 = Base.arraysize(_3, 1)::Int64                          │╻╷╷╷╷ macro expansion
  │    %2 = Base.slt_int(%1, 0)::Bool                             ││╻╷╷╷╷ eachindex
  │    %3 = Core.ifelse(%2, 0, %1)::Int64                         │││╻     axes1
  │    %4 = %new(Base.OneTo{Int64}, %3)::Base.OneTo{Int64}        ││││┃││││ axes
  └─── goto #14 if not true                                       │╻     macro expansion
  2 ── %6 = Base.slt_int(0, %3)::Bool                             ││╻     <
  └─── goto #12 if not %6                                         ││
  3 ── nothing::Nothing                                           │
  4 ┄─ %9 = φ (#3 => 0, #11 => %27)::Int64                        ││
  │    %10 = Base.slt_int(%9, %3)::Bool                           ││╻     <
  └─── goto #12 if not %10                                        ││
  5 ── %12 = Base.add_int(%9, 1)::Int64                           ││╻╷    simd_index
  └─── goto #9 if not false                                       │││╻     getindex
  6 ── %14 = Base.slt_int(0, %12)::Bool                           ││││╻     >
  │    %15 = Base.sle_int(%12, %3)::Bool                          ││││╻     <=
  │    %16 = Base.and_int(%14, %15)::Bool                         ││││╻     &
  └─── goto #8 if not %16                                         ││││
  7 ── goto #9                                                    │
  8 ── invoke Base.throw_boundserror(%4::Base.OneTo{Int64}, %12::Int64)::Union{}
  └─── unreachable                                                ││││
  9 ┄─ goto #10                                                   │
  10 ─ goto #11                                                   │
  11 ─ %23 = Base.arrayref(false, _2, %12)::Float64               ││╻╷    macro expansion
  │    %24 = Base.arrayref(false, _3, %12)::Float64               │││┃     getindex
  │    %25 = Base.add_float(%23, %24)::Float64                    │││╻     +
  │    Base.arrayset(false, _2, %25, %12)::Vector{Float64}        │││╻     setindex!
  │    %27 = Base.add_int(%9, 1)::Int64                           ││╻     +
  │    $(Expr(:loopinfo, Symbol("julia.simdloop"), nothing))::Nothing  │╻     macro expansion
  └─── goto #4                                                    ││
  12 ┄ goto #14 if not false                                      ││
  13 ─ nothing::Nothing                                           │
5 14 ┄ return _2                                                  │
```
The performance-critical chunk of the loop happens between `%23` and `%27`. Tapir.jl does basically the same kind of thing for each of these lines, so we just look at `%23`:
```julia
%23_ = rrule!!(zero_fcodual(Base.arrayref), zero_fcodual(false), _2, %12)
%23 = %23_[1]
push!(%23_pb_stack, %23_[2])
```
In short, we run the rule, pull out the first element of the result, and push the pullback to the stack for use on the reverse-pass.
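The forwards/reverse pattern described above can be sketched in plain Julia. This is an illustrative mock-up, not Tapir.jl's actual generated code, and all names here are made up:

```julia
# Mock-up of the pattern described above: the forwards-pass does the primal
# work and pushes one pullback per iteration onto a stack; the reverse-pass
# pops them off in LIFO order. Note that each pullback closes over the index
# `n`, so `n` gets stored on the stack at every iteration.
function forward!(z, x, y)
    pb_stack = Any[]
    for n in eachindex(x)
        z[n] = x[n] + y[n]  # primal work
        # pullback for +: propagate the cotangent of z[n] to both inputs
        push!(pb_stack, (dx, dy, dz) -> (dx[n] += dz[n]; dy[n] += dz[n]))
    end
    return pb_stack
end

function reverse!(pb_stack, dx, dy, dz)
    while !isempty(pb_stack)
        pop!(pb_stack)(dx, dy, dz)  # pop and run each pullback
    end
    return dx, dy
end
```

Running `forward!` on two 2-element vectors and then `reverse!` with a cotangent of ones accumulates a gradient of one into each element of `dx` and `dy`, as expected for addition.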
So there is at least one really large, obvious source of overhead here: pushing to / popping from the stacks. If you take a look at the pullbacks for the `arrayref` calls, you'll see that they contain a reference to the array being read and the index at which it was read. This information is necessary for AD, but we are pushing 4096 references to the same array onto a stack and popping them off again, which is wasteful. What's not obvious here, but is also important, is that the call to `push!` tends to get inlined and contains a branch. This prevents LLVM from vectorising the loop, thus prohibiting quite a lot of optimisation.
Now, Tapir.jl is implemented in such a way that, if the pullback for a particular function is a singleton / doesn't carry around any information, the associated pullback stack is eliminated entirely. Moreover, just reducing the amount of memory stored at each iteration should reduce memory pressure. Consequently, a good strategy for making progress is to figure out how to reduce the amount of stuff which gets stored in the pullback stacks. The two points noted above provide obvious starting points.
In short: amend the rule interface such that the arguments to the forwards-pass are also made available on the reverse-pass.
For example, the `arrayref` rule is presently something along the lines of:
```julia
function rrule!!(::CoDual{typeof(arrayref)}, inbounds::CoDual{Bool}, x::CoDual{Vector{Float64}}, ind::CoDual{Int})
    _ind = primal(ind)
    dx = tangent(x)
    function arrayref_pullback(dy)
        dx[_ind] += dy
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return CoDual(primal(x)[_ind], tangent(x)[_ind]), arrayref_pullback
end
```
This skips some details, but the important point is that `_ind` and `dx` are closed over, and are therefore stored in `arrayref_pullback`.
Under the new interface, this would look something like
```julia
function rrule!!(::CoDual{typeof(arrayref)}, inbounds::CoDual{Bool}, x::CoDual{Vector{Float64}}, ind::CoDual{Int})
    function arrayref_pullback(dy, ::CoDual{typeof(arrayref)}, ::CoDual{Bool}, x::CoDual{Vector{Float64}}, ind::CoDual{Int})
        _ind = primal(ind)
        dx = tangent(x)
        dx[_ind] += dy
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return CoDual(primal(x)[primal(ind)], tangent(x)[primal(ind)]), arrayref_pullback
end
```
In this version of the rule, `arrayref_pullback` is a singleton because it does not close over any data from the enclosing `rrule!!`.
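The singleton claim can be checked directly in the REPL. These are illustrative toy closures, not the actual Tapir.jl pullbacks:

```julia
# A closure which captures data has fields, so its type is not a singleton;
# a closure which captures nothing is a singleton type, and a stack of such
# pullbacks carries no information and can be eliminated entirely.
make_capturing_pb(dx, ind) = dy -> (dx[ind] += dy; nothing)
make_noncapturing_pb() = (dy, dx, ind) -> (dx[ind] += dy; nothing)

capturing = make_capturing_pb(zeros(3), 1)
noncapturing = make_noncapturing_pb()

Base.issingletontype(typeof(capturing))     # false: stores dx and ind
Base.issingletontype(typeof(noncapturing))  # true: no captured state
```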
So this interface change frees up Tapir.jl to provide the arguments on the reverse-pass in whichever way it pleases. In this particular example, both `x` and `y` are arguments to `foo!`, so applying this new interface recursively would give us direct access to them on the reverse-pass by construction.
A similar strategy could be employed for variables which aren't arguments by putting them in the storage shared by the forwards and reverse passes.
It's impossible to know for sure how much of an effect this would have, but doing this alone would more than halve the memory requirement for `arrayref` (a `Vector{Float64}` knows its address in memory and its length, which requires 16B of memory, vs an index which is just an `Int`, which takes 8B of memory), and do even more for `arrayset` (it requires references to the primal array and to the shadow). Since the pullback for `+` is already a singleton in both the `Float64` and `Int` cases, this would more than halve the memory footprint of the loop.
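To put rough numbers on the `arrayref` claim, using only the per-item figures stated above (this is back-of-envelope accounting, not a measurement):

```julia
# Per-iteration stack storage for the arrayref pullback, per the figures
# above: the current closure stores an array reference (16B: pointer plus
# length) and an Int index (8B); under the new interface only the index
# would remain on the stack, since the array arrives as an argument.
current_bytes  = 16 + 8   # closed-over dx reference + _ind
proposed_bytes = 8        # index only
proposed_bytes / current_bytes  # = 1/3: more than halved, as claimed
```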
I won't address how we could make use of induction variable analysis here, because I'm still trying to get my head around how best to go about it. Rather, just note that the above interface change is necessary in order to make use of its results: the purpose of induction variable analysis would be to avoid having to store the index on each iteration of the loop, and instead to re-compute it on the reverse-pass and hand it to the pullbacks. The above change to the interface would permit this.
Another obvious optimisation is to analyse the trip count, and pre-allocate the (necessary) pullback stacks in order to avoid branching during execution (i.e. checking that they're long enough to store the next pullback, and allocating more memory if not).
This is related to induction variable analysis, so we'd probably want to do that first.
Doing this kind of optimisation would enable vectorisation to happen more effectively in AD, as we could completely eliminate branching from a number of tight loops.
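One way the pre-allocation idea could look, sketched under the assumption that the trip count is known before the loop runs. The names and design here are illustrative, not Tapir.jl's actual stack implementation:

```julia
# A fixed-capacity stack: capacity is reserved once, up front, from the trip
# count, so the hot-loop store has no "grow if full" branch, unlike push!.
mutable struct TripCountStack{T}
    data::Vector{T}
    top::Int
end

TripCountStack{T}(trip_count::Int) where {T} =
    TripCountStack{T}(Vector{T}(undef, trip_count), 0)

# branch-free store: no capacity check, no reallocation
@inline function store!(s::TripCountStack, v)
    i = s.top + 1
    @inbounds s.data[i] = v
    s.top = i
    return s
end

# branch-free pop for the reverse-pass
@inline function load!(s::TripCountStack)
    i = s.top
    s.top = i - 1
    return @inbounds s.data[i]
end
```

Because `store!` and `load!` contain no control flow, a loop built from them has a better chance of being vectorised by LLVM than one built around `push!`/`pop!`.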
Good investigations; it's probably okay to keep this issue open instead of transferring discussions here into docs.
I still need to add more concrete info about what loop optimisations are possible, but here's a summary of the state of affairs currently:

- `map`, `broadcast`, `mapreduce`, and any other higher-order functions I’ve forgotten about, all lower to loops in the CFG. Tapir.jl doesn’t have rules for them, so Tapir.jl sees these loops.
- When the body of a loop does a reasonable amount of work at each iteration (e.g. computing `sin(cos(exp(x[n])))`), the time spent managing the “overhead” associated to looping (e.g. logging stuff on the forwards-pass at each iteration which you need on the reverse-pass) is small in comparison to the time spent doing the work that you care about (e.g. computing `sin(cos(exp(x[n])))` and doing AD on each operation in it, etc.).
- `sum` is an extreme case of the opposite, because adding two `Float64`s together at each iteration is about the cheapest differentiable operation that you could imagine doing. Moreover, the current way that we handle looping in Tapir.jl “gets in the way” of vectorisation (on the forwards-pass and reverse-pass).
- Loop optimisations of the kind discussed above should therefore help most on functions like `kron` and `sum`. I imagine they’ll be especially great on `sum`, as they should be able to “get out of the way” of vectorisation (i.e. things should vectorise nicely) in many cases.

That we just rely on everything boiling down to the same kind of looping structure in the CFG is a great advantage of this approach -- basically everything CPU-based that’s performant gets reduced to a loop in the CFG (specifically, a thing called a “natural loop” in compiler-optimisation terminology). There are well-established optimisation strategies for loops, so we don’t need to implement separate rules for all the different higher-order functions to get good performance, nor do we need to tell people to steer clear of writing `for` or `while` loops. Rather, we just optimise these so-called natural-loop structures which appear in the CFG, and then everything (or, rather, most things) will (should) be fast.
(The situation in which this strategy breaks down is if people use `@goto` to produce certain kinds of “weird” looping structures. Such structures will only ever be as performant as they are currently. Frankly, that’s not bad, but we should probably discourage people from using `@goto`, which is definitely something that I can live with.)