compintell / Tapir.jl

https://compintell.github.io/Tapir.jl/
MIT License

Document Loop Optimisation Opportunities #156

Open willtebbutt opened 1 month ago

willtebbutt commented 1 month ago

I still need to add more concrete info about what loop optimisations are possible, but here's a summary of the state of affairs currently:

A great advantage of this approach is that we can rely on everything boiling down to the same kind of looping structure in the CFG -- basically everything CPU-based that’s performant gets reduced to a loop in the CFG (specifically, a thing called a “Natural Loop” in compiler optimisation terminology). There are well-established optimisation strategies for loops, so we don’t need to implement separate rules for all the different higher-order functions to get good performance, nor do we need to tell people to steer clear of writing for or while loops. Rather, we just optimise these so-called “natural loop” structures which appear in the CFG, and then everything (or, rather, most things) will (should) be fast.

(The situation in which this strategy breaks down is if people use @goto to produce certain kinds of “weird” looping structures -- i.e. loops which are not natural loops. Such structures will only ever be as performant as they currently are. Frankly, that’s not bad, but we should probably discourage people from using @goto anyway, which is a restriction I can live with.)
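To illustrate the point (a hypothetical sketch, not Tapir.jl machinery -- function names are made up): superficially different loop forms all lower to the same natural-loop shape, which you can confirm by inspecting the typed IR.

```julia
# Two superficially different implementations; both lower to a single
# natural loop (one back-edge) in the CFG.
function sum_for(x::Vector{Float64})
    s = 0.0
    for v in x
        s += v
    end
    return s
end

function sum_while(x::Vector{Float64})
    s = 0.0
    i = firstindex(x)
    while i <= lastindex(x)
        s += x[i]
        i += 1
    end
    return s
end

# Inspect the CFGs in the REPL and observe the same loop structure:
#   @code_typed sum_for(randn(10))
#   @code_typed sum_while(randn(10))
```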

willtebbutt commented 1 week ago

Tapir.jl does not perform as well as it could on functions like the following:

function foo!(y::Vector{Float64}, x::Vector{Float64})
    @inbounds @simd for n in eachindex(x)
        y[n] = y[n] + x[n]
    end
    return y
end

For example, on my computer:

y = randn(4096)
x = randn(4096)

julia> @benchmark foo!($y, $x)
BenchmarkTools.Trial: 10000 samples with 173 evaluations.
 Range (min … max):  547.150 ns …   3.138 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     646.633 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   682.488 ns ± 116.548 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

       ▄██▂                                                      
  ▁▁▂▄▇████▇▇▇▆▆▅▅▅▅▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  547 ns           Histogram: frequency by time         1.18 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

rule = Tapir.build_rrule(foo!, y, x);
foo!_d = zero_fcodual(foo!)
y_d = zero_fcodual(y)
x_d = zero_fcodual(x)
out, pb!! = rule(foo!_d, y_d, x_d);

julia> @benchmark ($rule)($foo!_d, $y_d, $x_d)[2](NoRData())
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  64.042 μs … 202.237 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     78.675 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   75.763 μs ±  10.175 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▇ ▇ ▇ ▄▂       ▅ ▃  ▆ ▅▆ █▄ ▄  ▁▂         ▂       ▁        ▂ ▃
  █▃█▃█▄██▁▄▁▆▄▁▁█▄█▆██████████▇▄██▃▅█▃▃█▆▆▇█▆▆▆▇▆▄▁█▃▃▃▅█▅▃▁█ █
  64 μs         Histogram: log(frequency) by time       108 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

So the performance ratio is roughly 64 μs / 0.5 μs, i.e. about 128.

Note that this is not due to type-instabilities. One way to convince yourself of this is that no allocations are required to run AD, which would most certainly not be the case were there type instabilities. Rather, the problems are to do with the overhead associated with our implementation of reverse-mode AD.

To see this, take a look at the optimised IR for foo!:

2 1 ── %1  = Base.arraysize(_3, 1)::Int64                                    │╻╷╷╷╷    macro expansion
  │    %2  = Base.slt_int(%1, 0)::Bool                                       ││╻╷╷╷╷    eachindex
  │    %3  = Core.ifelse(%2, 0, %1)::Int64                                   │││╻        axes1
  │    %4  = %new(Base.OneTo{Int64}, %3)::Base.OneTo{Int64}                  ││││┃││││    axes
  └───       goto #14 if not true                                            │╻        macro expansion
  2 ── %6  = Base.slt_int(0, %3)::Bool                                       ││╻        <
  └───       goto #12 if not %6                                              ││       
  3 ──       nothing::Nothing                                                │        
  4 ┄─ %9  = φ (#3 => 0, #11 => %27)::Int64                                  ││       
  │    %10 = Base.slt_int(%9, %3)::Bool                                      ││╻        <
  └───       goto #12 if not %10                                             ││       
  5 ── %12 = Base.add_int(%9, 1)::Int64                                      ││╻╷       simd_index
  └───       goto #9 if not false                                            │││╻        getindex
  6 ── %14 = Base.slt_int(0, %12)::Bool                                      ││││╻        >
  │    %15 = Base.sle_int(%12, %3)::Bool                                     ││││╻        <=
  │    %16 = Base.and_int(%14, %15)::Bool                                    ││││╻        &
  └───       goto #8 if not %16                                              ││││     
  7 ──       goto #9                                                         │        
  8 ──       invoke Base.throw_boundserror(%4::Base.OneTo{Int64}, %12::Int64)::Union{}
  └───       unreachable                                                     ││││     
  9 ┄─       goto #10                                                        │        
  10 ─       goto #11                                                        │        
  11 ─ %23 = Base.arrayref(false, _2, %12)::Float64                          ││╻╷       macro expansion
  │    %24 = Base.arrayref(false, _3, %12)::Float64                          │││┃        getindex
  │    %25 = Base.add_float(%23, %24)::Float64                               │││╻        +
  │          Base.arrayset(false, _2, %25, %12)::Vector{Float64}             │││╻        setindex!
  │    %27 = Base.add_int(%9, 1)::Int64                                      ││╻        +
  │          $(Expr(:loopinfo, Symbol("julia.simdloop"), nothing))::Nothing  │╻        macro expansion
  └───       goto #4                                                         ││       
  12 ┄       goto #14 if not false                                           ││       
  13 ─       nothing::Nothing                                                │        
5 14 ┄       return _2                                                       │        

The performance-critical chunk of the loop happens between %23 and %27. Tapir.jl does basically the same kind of thing for each of these lines, so we just look at %23:

%23_ = rrule!!(zero_fcodual(Base.arrayref), zero_fcodual(false), _2, %12)
%23 = %23_[1]
push!(%23_pb_stack, %23_[2])

In short, we run the rule, pull out the first element of the result, and push the pullback to the stack for use on the reverse-pass.

So there is at least one large, obvious source of overhead here: pushing to / popping from the stacks. If you take a look at the pullbacks for the arrayref calls, you'll see that they contain:

  1. (a reference to) the shadow of the array being referenced, and
  2. a copy of the index at which the forwards-pass references the array.

This information is necessary for AD, but

  1. the array being referenced and its shadow are loop invariants -- their value does not change at each iteration of the loop -- meaning that we're just pushing 4096 references to the same array to a stack and popping them, which is wasteful, and
  2. the index is an induction variable -- its value changes by a fixed known amount at each loop iteration, meaning that (in principle) we can just recompute it on the reverse-pass rather than storing it.

What's not obvious here, but is also important, is that the call to push! tends to get inlined and contains a branch. This prevents LLVM from vectorising the loop, thus prohibiting quite a lot of optimisation.
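As a rough illustration of that cost (hypothetical functions, not Tapir.jl's generated code), compare the shape of a forwards pass that records per-iteration state with one that does not:

```julia
# Forwards pass that mimics AD bookkeeping: one stack entry per iteration.
# The push! hides a grow-check branch inside the hot loop, which blocks
# vectorisation.
function fwd_with_stack!(y::Vector{Float64}, x::Vector{Float64})
    pb_stack = Int[]
    @inbounds for n in eachindex(x)
        y[n] = y[n] + x[n]
        push!(pb_stack, n)   # branch + possible reallocation each iteration
    end
    return pb_stack
end

# The same loop with no per-iteration bookkeeping: branch-free, vectorises.
function fwd_without_stack!(y::Vector{Float64}, x::Vector{Float64})
    @inbounds @simd for n in eachindex(x)
        y[n] = y[n] + x[n]
    end
    return y
end
```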

Now, Tapir.jl is implemented in such a way that, if the pullback for a particular function is a singleton / doesn't carry around any information, the associated pullback stack is eliminated entirely. Moreover, just reducing the amount of memory stored at each iteration should reduce memory pressure. Consequently, a good strategy for making progress is to figure out how to reduce the amount of stuff which gets stored in the pullback stacks. The two points noted above provide obvious starting points.

Making use of loop invariants

In short: amend the rule interface such that the arguments to the forwards pass are also made available on the reverse pass.

For example, the arrayref rule is presently something along the lines of

function rrule!!(::CoDual{typeof(arrayref)}, inbounds::CoDual{Bool}, x::CoDual{Vector{Float64}}, ind::CoDual{Int})
    _ind = primal(ind)
    dx = tangent(x)
    function arrayref_pullback(dy)
        dx[_ind] += dy
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return CoDual(primal(x)[_ind], tangent(x)[_ind]), arrayref_pullback
end

This skips some details, but the important point is that _ind and dx are closed over, and are therefore stored in arrayref_pullback.

Under the new interface, this would look something like

function rrule!!(::CoDual{typeof(arrayref)}, inbounds::CoDual{Bool}, x::CoDual{Vector{Float64}}, ind::CoDual{Int})
    function arrayref_pullback(dy, ::CoDual{typeof(arrayref)}, ::CoDual{Bool}, x::CoDual{Vector{Float64}}, ind::CoDual{Int})
        dx = tangent(x)
        dx[primal(ind)] += dy
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    _ind = primal(ind)
    return CoDual(primal(x)[_ind], tangent(x)[_ind]), arrayref_pullback
end

In this version of the rule, arrayref_pullback is a singleton because it does not close over any data from the enclosing rrule!!.
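A quick way to see why singletons matter (an illustrative toy, not Tapir.jl internals -- `capture` / `no_capture` are made-up names): a closure that captures data has a non-singleton type and must be stored per iteration, whereas a function that receives the same data as arguments is a singleton, so a stack of them stores nothing.

```julia
# `capture` returns a closure storing `ind` and `dx` as fields;
# `no_capture` takes them as arguments instead, so it captures nothing.
capture(ind, dx) = dy -> (dx[ind] += dy; nothing)
no_capture(dy, ind, dx) = (dx[ind] += dy; nothing)

dx = zeros(3)
Base.issingletontype(typeof(capture(2, dx)))  # false: each instance stores data
Base.issingletontype(typeof(no_capture))      # true: a stack of these is free
```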

So this interface change frees up Tapir.jl to provide the arguments on the reverse-pass in whichever way it pleases. In this particular example, both x and y are arguments to foo!, so applying this new interface recursively would give us direct access to them on the reverse pass by construction. A similar strategy could be employed for variables which aren't arguments by putting them in the storage shared by the forwards and reverse passes.

It's impossible to know for sure how much of an effect this would have, but doing this alone would more than halve the memory requirement for arrayref (a Vector{Float64} stores its address in memory and its length, requiring 16B, vs an index which is just an 8B Int), and do even more for arrayset (which requires references to both the primal array and its shadow). Since the pullback for + is already a singleton in both the Float64 and Int cases, this would more than halve the memory footprint of the loop.

Induction Variable Analysis

I won't address how we could make use of induction variable analysis here because I'm still trying to get my head around how best to go about it. Rather, just note that the above interface change is necessary in order to make use of its results -- the purpose of induction variable analysis would be to avoid storing the index on each iteration of the loop, and instead to re-compute it on the reverse pass and hand it to the pullbacks. The above change to the interface would permit this.
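As a sketch of what this would buy us (hand-written and hypothetical -- Tapir.jl does not currently generate this), the reverse pass for the foo! loop could regenerate the indices by iterating in reverse, storing nothing per iteration:

```julia
# Hand-written reverse pass for the `foo!` loop: since the index is an
# induction variable, we recompute it by iterating in reverse rather than
# popping 4096 stored copies of it from a stack.
function foo_reverse_pass!(dx::Vector{Float64}, dy::Vector{Float64})
    for n in reverse(eachindex(dy))
        dx[n] += dy[n]   # adjoint of `y[n] = y[n] + x[n]` w.r.t. x[n]
    end
    return dx
end
```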

willtebbutt commented 1 week ago

Another obvious optimisation is to analyse the trip count and pre-allocate the (necessary) pullback stacks, in order to avoid branching during execution (i.e. checking on each push! that the stack is long enough to store the next pullback, and allocating more memory if not).

This is related to induction variable analysis, so we'd probably want to do that first.

Doing this kind of optimisation would enable vectorisation to happen more effectively in AD, as we could completely eliminate branching from a number of tight loops.
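A sketch of the idea (hypothetical helper, not Tapir.jl code): if the trip count is known before the loop runs, the stack can be allocated once up front and written with plain setindex!, removing the grow-check branch from the hot loop.

```julia
# Forwards pass with a pre-allocated "pullback stack": the trip count is
# computed before the loop, so the loop body contains only branch-free stores.
function fwd_prealloc!(y::Vector{Float64}, x::Vector{Float64})
    n_iters = length(eachindex(x))           # trip count known up front
    pb_stack = Vector{Int}(undef, n_iters)   # one allocation, no growth
    @inbounds for (k, n) in enumerate(eachindex(x))
        y[n] = y[n] + x[n]
        pb_stack[k] = n                      # plain store, no grow-check
    end
    return pb_stack
end
```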

yebai commented 1 week ago

Good investigations; it's probably okay to keep this issue open instead of transferring discussions here into docs.