chriselrod / ReverseDiffExpressions.jl

Library facilitating source-to-source reverse-mode differentiation of Julia expressions, with a focus on probability distributions.
MIT License
2 stars 0 forks source link

How to add ChainRules or Zygote support for LoopVectorization? #2

Open chriselrod opened 4 years ago

chriselrod commented 4 years ago

@willtebbutt @oxinabox

I'd like suggestions/advice on how to add autodiff support for LoopVectorization.@avx. Where and how should we integrate it so that it works with the various AD tools?

In this issue, I'll briefly describe LoopVectorization's API, and what ReverseDiffExpressions.jl can currently do. One of the goals of this issue is to define what sort of an API/user interface it should expose. (Quick note in case anyone looks at the src: most of the code in here is not include-ed; it's old and will soon be deleted.)

LoopVectorization.@avx for ... works by calling a generated function, LoopVectorization._avx_!, and passing it a description of the loops. The call is guarded by an if check_args(...); if check_args returns false, a fallback loop is used instead.

julia> using LoopVectorization

julia> @macroexpand @avx for m ∈ axes(C,1), n ∈ axes(C,2)
           C[m,n] = zero(eltype(B))
           for k ∈ axes(B,1)
               C[m,n] += A[m,k] * B[k,n]
           end
       end
quote
    begin
        var"##loopm#253" = LoopVectorization.maybestaticrange(axes(C, 1))
        var"##m_loop_lower_bound#254" = LoopVectorization.maybestaticfirst(var"##loopm#253")
        var"##m_loop_upper_bound#255" = LoopVectorization.maybestaticlast(var"##loopm#253")
        var"##loopn#256" = LoopVectorization.maybestaticrange(axes(C, 2))
        var"##n_loop_lower_bound#257" = LoopVectorization.maybestaticfirst(var"##loopn#256")
        var"##n_loop_upper_bound#258" = LoopVectorization.maybestaticlast(var"##loopn#256")
        var"##loopk#261" = LoopVectorization.maybestaticrange(axes(B, 1))
        var"##k_loop_lower_bound#262" = LoopVectorization.maybestaticfirst(var"##loopk#261")
        var"##k_loop_upper_bound#263" = LoopVectorization.maybestaticlast(var"##loopk#261")
    end
    if LoopVectorization.check_args(C, A, B)
        var"##vptr##_C" = LoopVectorization.stridedpointer(C)
        var"##vptr##_A" = LoopVectorization.stridedpointer(A)
        var"##vptr##_B" = LoopVectorization.stridedpointer(B)
        begin
            $(Expr(:gc_preserve, quote
    begin
        var"##Tloopeltype##" = promote_type(eltype(C), eltype(A), eltype(B))
        var"##Wvecwidth##" = LoopVectorization.pick_vector_width_val(eltype(C), eltype(A), eltype(B))
    end
    LoopVectorization._avx_!(Val{(0, 0, 0, LoopVectorization.unwrap(var"##Wvecwidth##"))}(), Tuple{:numericconstant, Symbol("##zero#260"), LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000000, 0x0000000000000003, 0x0000000000000000, LoopVectorization.constant, 0x00, 0x01), :LoopVectorization, :setindex!, LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000003, 0x0000000000000000, 0x0000000000000006, LoopVectorization.memstore, 0x01, 0x02), :LoopVectorization, :getindex, LoopVectorization.OperationStruct(0x0000000000000013, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, LoopVectorization.memload, 0x02, 0x03), :LoopVectorization, :getindex, LoopVectorization.OperationStruct(0x0000000000000032, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, LoopVectorization.memload, 0x03, 0x04), :LoopVectorization, :vfmadd_fast, LoopVectorization.OperationStruct(0x0000000000000132, 0x0000000000000003, 0x0000000000000000, 0x0000000000030401, LoopVectorization.compute, 0x00, 0x01), :LoopVectorization, :identity, LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000003, 0x0000000000000000, 0x0000000000000005, LoopVectorization.compute, 0x00, 0x01)}, Tuple{LoopVectorization.ArrayRefStruct{:C,Symbol("##vptr##_C")}(0x0000000000000101, 0x0000000000000102, 0x0000000000000000), LoopVectorization.ArrayRefStruct{:A,Symbol("##vptr##_A")}(0x0000000000000101, 0x0000000000000103, 0x0000000000000000), LoopVectorization.ArrayRefStruct{:B,Symbol("##vptr##_B")}(0x0000000000000101, 0x0000000000000302, 0x0000000000000000)}, Tuple{0, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{(1, LoopVectorization.IntOrFloat)}, Tuple{}}, Tuple{:m, :n, :k}, (var"##m_loop_lower_bound#254":var"##m_loop_upper_bound#255", var"##n_loop_lower_bound#257":var"##n_loop_upper_bound#258", var"##k_loop_lower_bound#262":var"##k_loop_upper_bound#263"), var"##vptr##_C", var"##vptr##_A", var"##vptr##_B")
    nothing
end, :C, :A, :B))
        end
    else
        $(Expr(:inbounds, true))
        local var"#18#val" = for m = axes(C, 1), n = axes(C, 2)
                    #= REPL[5]:2 =#
                    C[m, n] = zero(eltype(B))
                    #= REPL[5]:3 =#
                    for k = axes(B, 1)
                        #= REPL[5]:4 =#
                        begin
                            #= fastmath.jl:115 =#
                            var"##267" = C
                            #= fastmath.jl:116 =#
                            (var"##268", var"##269") = (m, n)
                            #= fastmath.jl:117 =#
                            var"##267"[var"##268", var"##269"] = Base.FastMath.add_fast(var"##267"[var"##268", var"##269"], Base.FastMath.mul_fast(A[m, k], B[k, n]))
                        end
                    end
                end
        $(Expr(:inbounds, :pop))
        var"#18#val"
    end
end

We can do this manually via:

julia> AmulBq = :(for m ∈ axes(C,1), n ∈ axes(C,2)
           C[m,n] = zero(eltype(B))
           for k ∈ axes(B,1)
               C[m,n] += A[m,k] * B[k,n]
           end
       end);

julia> lsAmulB = LoopVectorization.LoopSet(AmulBq);

julia> LoopVectorization.setup_call(lsAmulB)
quote
    begin
        var"##loopm#270" = LoopVectorization.maybestaticrange(axes(C, 1))
        var"##m_loop_lower_bound#271" = LoopVectorization.maybestaticfirst(var"##loopm#270")
        var"##m_loop_upper_bound#272" = LoopVectorization.maybestaticlast(var"##loopm#270")
        var"##loopn#273" = LoopVectorization.maybestaticrange(axes(C, 2))
        var"##n_loop_lower_bound#274" = LoopVectorization.maybestaticfirst(var"##loopn#273")
        var"##n_loop_upper_bound#275" = LoopVectorization.maybestaticlast(var"##loopn#273")
        var"##loopk#278" = LoopVectorization.maybestaticrange(axes(B, 1))
        var"##k_loop_lower_bound#279" = LoopVectorization.maybestaticfirst(var"##loopk#278")
        var"##k_loop_upper_bound#280" = LoopVectorization.maybestaticlast(var"##loopk#278")
    end
    begin
        var"##vptr##_C" = LoopVectorization.stridedpointer(C)
        var"##vptr##_A" = LoopVectorization.stridedpointer(A)
        var"##vptr##_B" = LoopVectorization.stridedpointer(B)
        begin
            #= /home/chriselrod/.julia/dev/LoopVectorization/src/lowering.jl:546 =# GC.@preserve C A B begin
                    begin
                        var"##Tloopeltype##" = promote_type(eltype(C), eltype(A), eltype(B))
                        var"##Wvecwidth##" = LoopVectorization.pick_vector_width_val(eltype(C), eltype(A), eltype(B))
                    end
                    LoopVectorization._avx_!(Val{(0, 0, 0, LoopVectorization.unwrap(var"##Wvecwidth##"))}(), Tuple{:numericconstant, Symbol("##zero#277"), LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000000, 0x0000000000000003, 0x0000000000000000, LoopVectorization.constant, 0x00, 0x01), :LoopVectorization, :setindex!, LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000003, 0x0000000000000000, 0x0000000000000006, LoopVectorization.memstore, 0x01, 0x02), :LoopVectorization, :getindex, LoopVectorization.OperationStruct(0x0000000000000013, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, LoopVectorization.memload, 0x02, 0x03), :LoopVectorization, :getindex, LoopVectorization.OperationStruct(0x0000000000000032, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, LoopVectorization.memload, 0x03, 0x04), :LoopVectorization, :vfmadd_fast, LoopVectorization.OperationStruct(0x0000000000000132, 0x0000000000000003, 0x0000000000000000, 0x0000000000030401, LoopVectorization.compute, 0x00, 0x01), :LoopVectorization, :identity, LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000003, 0x0000000000000000, 0x0000000000000005, LoopVectorization.compute, 0x00, 0x01)}, Tuple{LoopVectorization.ArrayRefStruct{:C,Symbol("##vptr##_C")}(0x0000000000000101, 0x0000000000000102, 0x0000000000000000), LoopVectorization.ArrayRefStruct{:A,Symbol("##vptr##_A")}(0x0000000000000101, 0x0000000000000103, 0x0000000000000000), LoopVectorization.ArrayRefStruct{:B,Symbol("##vptr##_B")}(0x0000000000000101, 0x0000000000000302, 0x0000000000000000)}, Tuple{0, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{(1, LoopVectorization.IntOrFloat)}, Tuple{}}, Tuple{:m, :n, :k}, (var"##m_loop_lower_bound#271":var"##m_loop_upper_bound#272", var"##n_loop_lower_bound#274":var"##n_loop_upper_bound#275", var"##k_loop_lower_bound#279":var"##k_loop_upper_bound#280"), var"##vptr##_C", var"##vptr##_A", var"##vptr##_B")
                    nothing
                end
        end
    end
end

(If we had passed a quote as the second argument to LoopVectorization.setup_call, it would have been used as the fallback loop for an `if check_args` branch.)

We can create a ∂LoopSet object that holds forward and reverse pass loop sets via:

julia> using ReverseDiffExpressions, ReverseDiffExpressions.LoopSetDerivatives

julia> tracked_vars = Set([:A, :B]);

julia> ∂lsAmulB = LoopSetDerivatives.∂LoopSet(lsAmulB, tracked_vars);

julia> @eval AmulB_forward!(C, A, B) = $(LoopVectorization.setup_call(∂lsAmulB.fls))
AmulB_forward! (generic function with 1 method)

julia> @eval AmulB_reverse!(var"A##BAR##", var"B##BAR##", var"C##BAR##", A, B, C) = $(LoopVectorization.setup_call(∂lsAmulB.rls))
AmulB_reverse! (generic function with 1 method)

julia> M = K = N = 72;

julia> A = rand(M, K); B = rand(K, N); C = Matrix{Float64}(undef, M, N);

julia> AmulB_forward!(C, A, B); C ≈ A * B
true

julia> Cbar = rand(M, N); Abar = zero(A); Bbar = zero(B);

julia> AmulB_reverse!(Abar, Bbar, Cbar, A, B, C); Abar ≈ Cbar * B' # closed form
true

julia> Bbar ≈ A' * Cbar # closed form
true

Similarly, for a reduction to a scalar:

tq = :(for i in eachindex(C)
    target += log1p(0.25*(C[i] - 18)^2)
end);

lst = LoopVectorization.LoopSet(tq);

∂lst = LoopSetDerivatives.∂LoopSet(lst, Set([:C]));

This example demonstrates an additional obstacle: managing temporary arrays. When the computations are cheap (e.g., less than 3x the cost of a load + store), I will probably have it simply recompute values rather than store temporaries. In general, though, we will want to be able to store temporaries between the forward and backward passes.

julia> ∂lst.temparrays
2-element Array{LoopVectorization.ArrayReferenceMeta,1}:
 LoopVectorization.ArrayReferenceMeta(LoopVectorization.ArrayReference(Symbol("##temporaryarray#404"), [:i], Int8[0]), Bool[1], Symbol("##vptr##_##temporaryarray#404"))
 LoopVectorization.ArrayReferenceMeta(LoopVectorization.ArrayReference(Symbol("##temporaryarray#409"), [:i], Int8[0]), Bool[1], Symbol("##vptr##_##temporaryarray#409"))
julia> temp1qn = LoopVectorization.name(∂lst.temparrays[1]);

julia> temp2qn = LoopVectorization.name(∂lst.temparrays[2]);

julia> @eval function t_forward!($temp1qn, $temp2qn, C)
           target = 0.0
           $(LoopVectorization.setup_call(∂lst.fls))
           target
       end
t_forward! (generic function with 1 method)

julia> @eval t_reverse!(var"C##BAR##", var"target##BAR##", $temp1qn, $temp2qn, C) = $(LoopVectorization.setup_call(∂lst.rls))
t_reverse! (generic function with 1 method)

julia> temp1 = similar(C); temp2 = similar(C);

julia> function loop_diff_example!(Cb, Ab, Bb, temp1, temp2, C, A, B)
          AmulB_forward!(C, A, B)
          t = -2.5 * t_forward!(temp1, temp2, C)
          fill!(Cb, 0) # These fills can be made lazy with a ZeroIntitialized wrapper
          t_reverse!(Cb, -2.5, temp1, temp2, C)
          fill!(Ab, 0); fill!(Bb, 0);
          AmulB_reverse!(Ab, Bb, Cb, A, B, C)
          t
       end
loop_diff_example! (generic function with 1 method)

julia> diff_example(A, B) = -2.5sum(c -> log1p(0.25*(c - 18)^2), A * B)
diff_example (generic function with 1 method)

julia> loop_diff_example!(Cbar, Abar, Bbar, temp1, temp2, C, A, B)
-6233.279504712384

julia> diff_example(A, B)
-6233.279504712382

julia> Ab, Bb = Zygote.gradient(diff_example, A, B)
([-35.47245105939837 -36.82736614107259 … -35.75041188397073 -37.10989215069272; -3.2505037185682966 -2.9903189244189283 … -0.539397415346338 -3.8255420209357704; … ; -27.205328260777737 -29.017228952257252 … -28.08338041611006 -29.097446421822415; -34.144150300833175 -34.37638675813865 … -36.40192762020005 -35.462327518544186], [16.30863481482425 -12.981458571271965 … -31.748418425531508 33.360925272640394; 22.35533973543051 -13.167204352167206 … -35.62760400880121 35.8395693069914; … ; 18.905579589637522 -11.525628608957025 … -33.54210166756797 37.22284630871768; 22.698958440592133 -12.464500964355727 … -37.1305633915512 37.4114781431934])

julia> Abar, Bbar
([-35.47245105939837 -36.82736614107259 … -35.75041188397073 -37.10989215069272; -3.2505037185682975 -2.9903189244189288 … -0.5393974153463376 -3.8255420209357696; … ; -27.205328260777737 -29.017228952257252 … -28.08338041611006 -29.097446421822415; -34.144150300833175 -34.37638675813865 … -36.40192762020005 -35.46232751854418], [16.30863481482425 -12.981458571271965 … -31.748418425531504 33.3609252726404; 22.3553397354305 -13.167204352167207 … -35.627604008801214 35.8395693069914; … ; 18.905579589637526 -11.525628608957026 … -33.54210166756796 37.22284630871767; 22.698958440592133 -12.464500964355729 … -37.130563391551206 37.411478143193406])

julia> Ab ≈ Abar, Bb ≈ Bbar
(true, true)

Obligatory benchmark:

julia> @benchmark Zygote.gradient(diff_example, $A, $B)
BenchmarkTools.Trial:
  memory estimate:  853.05 KiB
  allocs estimate:  15626
  --------------
  minimum time:     318.435 μs (0.00% GC)
  median time:      335.337 μs (0.00% GC)
  mean time:        379.646 μs (11.26% GC)
  maximum time:     2.956 ms (85.11% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark loop_diff_example!($Cbar, $Abar, $Bbar, $temp1, $temp2, $C, $A, $B)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     49.135 μs (0.00% GC)
  median time:      49.385 μs (0.00% GC)
  mean time:        49.445 μs (0.00% GC)
  maximum time:     107.355 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

Currently, these benchmarks require the master branches of LoopVectorization, ReverseDiffExpressionsBase (unregistered), and ReverseDiffExpressions (unregistered).

I'll probably tag another release of LoopVectorization soon. The updates from the latest release are very minor.

willtebbutt commented 4 years ago

Thanks for tagging us in this. I'll need to stare at this for a while to get my head around it -- will get back to you as soon as possible.

chriselrod commented 4 years ago

I'm happy to answer any questions if it helps.

But basically, the code shows

  1. Manually creating LoopSet objects, which LoopVectorization uses internally to represent loops.
  2. Using setup_call to create the boilerplate wrapper for _avx_!, the generated function that gets called under the hood. This boilerplate is normally created by the @avx macro; here we use @eval instead.
  3. How to create a ∂LoopSet object from a LoopSet. The ∂LoopSet holds a forward-pass LoopSet and a reverse-pass LoopSet.

So the central question here is: what approaches can we take to add LoopVectorization support to the various AD libraries, and what is needed?

Something I should have emphasized is that an _avx_! call:

_avx_!(Val{(0, 0, 0, LoopVectorization.unwrap(var"##Wvecwidth##"))}(), Tuple{:numericconstant, Symbol("##zero#260"), LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000000, 0x0000000000000003, 0x0000000000000000, LoopVectorization.constant, 0x00, 0x01), :LoopVectorization, :setindex!, LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000003, 0x0000000000000000, 0x0000000000000006, LoopVectorization.memstore, 0x01, 0x02), :LoopVectorization, :getindex, LoopVectorization.OperationStruct(0x0000000000000013, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, LoopVectorization.memload, 0x02, 0x03), :LoopVectorization, :getindex, LoopVectorization.OperationStruct(0x0000000000000032, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, LoopVectorization.memload, 0x03, 0x04), :LoopVectorization, :vfmadd_fast, LoopVectorization.OperationStruct(0x0000000000000132, 0x0000000000000003, 0x0000000000000000, 0x0000000000030401, LoopVectorization.compute, 0x00, 0x01), :LoopVectorization, :identity, LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000003, 0x0000000000000000, 0x0000000000000005, LoopVectorization.compute, 0x00, 0x01)}, Tuple{LoopVectorization.ArrayRefStruct{:C,Symbol("##vptr##_C")}(0x0000000000000101, 0x0000000000000102, 0x0000000000000000), LoopVectorization.ArrayRefStruct{:A,Symbol("##vptr##_A")}(0x0000000000000101, 0x0000000000000103, 0x0000000000000000), LoopVectorization.ArrayRefStruct{:B,Symbol("##vptr##_B")}(0x0000000000000101, 0x0000000000000302, 0x0000000000000000)}, Tuple{0, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{(1, LoopVectorization.IntOrFloat)}, Tuple{}}, Tuple{:m, :n, :k}, (var"##m_loop_lower_bound#254":var"##m_loop_upper_bound#255", var"##n_loop_lower_bound#257":var"##n_loop_upper_bound#258", var"##k_loop_lower_bound#262":var"##k_loop_upper_bound#263"), var"##vptr##_C", var"##vptr##_A", var"##vptr##_B")

reconstructs a LoopSet object from all the type information in its Tuple arguments. So we should be able to define a ChainRules.rrule for _avx_! that returns the appropriate pullbacks. Perhaps I should work on that as the next example.

I would probably also need to do something about all the boilerplate code generated around the _avx_! call:

    begin
        var"##loopm#253" = LoopVectorization.maybestaticrange(axes(C, 1))
        var"##m_loop_lower_bound#254" = LoopVectorization.maybestaticfirst(var"##loopm#253")
        var"##m_loop_upper_bound#255" = LoopVectorization.maybestaticlast(var"##loopm#253")
        var"##loopn#256" = LoopVectorization.maybestaticrange(axes(C, 2))
        var"##n_loop_lower_bound#257" = LoopVectorization.maybestaticfirst(var"##loopn#256")
        var"##n_loop_upper_bound#258" = LoopVectorization.maybestaticlast(var"##loopn#256")
        var"##loopk#261" = LoopVectorization.maybestaticrange(axes(B, 1))
        var"##k_loop_lower_bound#262" = LoopVectorization.maybestaticfirst(var"##loopk#261")
        var"##k_loop_upper_bound#263" = LoopVectorization.maybestaticlast(var"##loopk#261")
    end
    if LoopVectorization.check_args(C, A, B)
        var"##vptr##_C" = LoopVectorization.stridedpointer(C)
        var"##vptr##_A" = LoopVectorization.stridedpointer(A)
        var"##vptr##_B" = LoopVectorization.stridedpointer(B)
        begin
            $(Expr(:gc_preserve, quote
    begin
        var"##Tloopeltype##" = promote_type(eltype(C), eltype(A), eltype(B))
        var"##Wvecwidth##" = LoopVectorization.pick_vector_width_val(eltype(C), eltype(A), eltype(B))
    end
    LoopVectorization._avx_!(...)
oxinabox commented 4 years ago

> So we should be able to define a ChainRules.rrule definition for _avx_! that returns the appropriate pull backs.

That sounds right to me. We support arbitrary types, so we should be able to support whatever you need. But I'm not sure you will even need that.

The only limitation is that we don't readily support calling back into the current AD, though you can hard-code it to always call into Zygote. I'm not sure whether that is needed.

> we will probably want to be able to store temporaries between the forward and backwards passes generally.

The pullback is a closure. If you create temporaries in the forward pass and reuse them inside the pullback, that will not allocate.

chriselrod commented 2 years ago

This relied on internals and, of course, broke a long time ago. Anyone wanting to look at this later should try checking out https://github.com/JuliaSIMD/LoopVectorization.jl/tree/v0.8.21 (LoopVectorization 0.8.21). It was released on the same day I filed this issue, which makes it the most likely version to actually work.