Open chriselrod opened 4 years ago
Thanks for tagging us in this. I'll need to stare at this for a while to get my head around it -- will get back to you as soon as possible.
I'm happy to answer any questions if it helps.
But basically, the code shows:

- `LoopSet` objects, which LoopVectorization uses internally to represent loops.
- `setup_call` to create the boilerplate wrapper for `_avx_!`, the generated function that gets called under the hood. While this boilerplate is normally created via the macro `@avx`, here we use `@eval`.
- Creating a `∂LoopSet` object from a `LoopSet`. The `∂LoopSet` holds a forward and a reverse pass `LoopSet`.

So, the central question here is about approaches to add LoopVectorization support to various AD libraries? What is needed?
Something I should have emphasized is that an `_avx_!` call:
_avx_!(Val{(0, 0, 0, LoopVectorization.unwrap(var"##Wvecwidth##"))}(), Tuple{:numericconstant, Symbol("##zero#260"), LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000000, 0x0000000000000003, 0x0000000000000000, LoopVectorization.constant, 0x00, 0x01), :LoopVectorization, :setindex!, LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000003, 0x0000000000000000, 0x0000000000000006, LoopVectorization.memstore, 0x01, 0x02), :LoopVectorization, :getindex, LoopVectorization.OperationStruct(0x0000000000000013, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, LoopVectorization.memload, 0x02, 0x03), :LoopVectorization, :getindex, LoopVectorization.OperationStruct(0x0000000000000032, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, LoopVectorization.memload, 0x03, 0x04), :LoopVectorization, :vfmadd_fast, LoopVectorization.OperationStruct(0x0000000000000132, 0x0000000000000003, 0x0000000000000000, 0x0000000000030401, LoopVectorization.compute, 0x00, 0x01), :LoopVectorization, :identity, LoopVectorization.OperationStruct(0x0000000000000012, 0x0000000000000003, 0x0000000000000000, 0x0000000000000005, LoopVectorization.compute, 0x00, 0x01)}, Tuple{LoopVectorization.ArrayRefStruct{:C,Symbol("##vptr##_C")}(0x0000000000000101, 0x0000000000000102, 0x0000000000000000), LoopVectorization.ArrayRefStruct{:A,Symbol("##vptr##_A")}(0x0000000000000101, 0x0000000000000103, 0x0000000000000000), LoopVectorization.ArrayRefStruct{:B,Symbol("##vptr##_B")}(0x0000000000000101, 0x0000000000000302, 0x0000000000000000)}, Tuple{0, Tuple{}, Tuple{}, Tuple{}, Tuple{}, Tuple{(1, LoopVectorization.IntOrFloat)}, Tuple{}}, Tuple{:m, :n, :k}, (var"##m_loop_lower_bound#254":var"##m_loop_upper_bound#255", var"##n_loop_lower_bound#257":var"##n_loop_upper_bound#258", var"##k_loop_lower_bound#262":var"##k_loop_upper_bound#263"), var"##vptr##_C", var"##vptr##_A", var"##vptr##_B")
reconstructs a `LoopSet` object from all the type information in the `Tuple` arguments.
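To illustrate the idea of rebuilding a computation from type information, here is a minimal sketch of a generated function that reads a tuple of symbols out of a `Val` type parameter and turns it into code. The name `apply_chain` and the encoding are made up for this example; LoopVectorization's actual encoding is the much richer `OperationStruct`/`ArrayRefStruct` layout shown in the call above.

```julia
# Hypothetical example: the "program" lives entirely in the type parameter `fs`,
# so the generated function can rebuild an expression from it at compile time.
@generated function apply_chain(::Val{fs}, x) where {fs}
    ex = :x
    for f in fs                      # fs is a tuple of Symbols, e.g. (:abs, :sqrt)
        ex = Expr(:call, f, ex)      # wrap the expression so far in a call to f
    end
    return ex                        # e.g. :(sqrt(abs(x)))
end

apply_chain(Val((:abs, :sqrt)), -4.0)   # evaluates sqrt(abs(-4.0))
```

Because the description is carried by types alone, the same information would also be available to any other method that dispatches on those types.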
So we should be able to define a `ChainRules.rrule` definition for `_avx_!` that returns the appropriate pullbacks. Perhaps I should work on that as the next example.
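For reference, this is roughly the shape such a rule takes. It is only a sketch: `loop_matmul` is a hypothetical plain-loop stand-in for an `@avx` kernel, not `_avx_!` itself, and the code assumes a recent ChainRulesCore where the "no tangent" marker is called `NoTangent()`.

```julia
using ChainRulesCore
using LinearAlgebra

# Hypothetical stand-in for an @avx kernel: C[m, n] = sum over k of A[m, k] * B[k, n]
function loop_matmul(A::AbstractMatrix, B::AbstractMatrix)
    C = zeros(promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2))
    for n in axes(C, 2), k in axes(B, 1), m in axes(C, 1)
        C[m, n] += A[m, k] * B[k, n]
    end
    return C
end

function ChainRulesCore.rrule(::typeof(loop_matmul), A::AbstractMatrix, B::AbstractMatrix)
    C = loop_matmul(A, B)
    function loop_matmul_pullback(ΔC)
        dC = unthunk(ΔC)
        dA = loop_matmul(dC, B')     # dC * B'
        dB = loop_matmul(A', dC)     # A' * dC
        return NoTangent(), dA, dB   # first slot is the tangent of the function itself
    end
    return C, loop_matmul_pullback
end
```

A rule for `_avx_!` would presumably have the same outer shape, with the forward and reverse loops coming from the `∂LoopSet` instead of being written by hand.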
I would probably also need to do something about all the boilerplate code created surrounding the `_avx_!` call.
begin
    var"##loopm#253" = LoopVectorization.maybestaticrange(axes(C, 1))
    var"##m_loop_lower_bound#254" = LoopVectorization.maybestaticfirst(var"##loopm#253")
    var"##m_loop_upper_bound#255" = LoopVectorization.maybestaticlast(var"##loopm#253")
    var"##loopn#256" = LoopVectorization.maybestaticrange(axes(C, 2))
    var"##n_loop_lower_bound#257" = LoopVectorization.maybestaticfirst(var"##loopn#256")
    var"##n_loop_upper_bound#258" = LoopVectorization.maybestaticlast(var"##loopn#256")
    var"##loopk#261" = LoopVectorization.maybestaticrange(axes(B, 1))
    var"##k_loop_lower_bound#262" = LoopVectorization.maybestaticfirst(var"##loopk#261")
    var"##k_loop_upper_bound#263" = LoopVectorization.maybestaticlast(var"##loopk#261")
end
if LoopVectorization.check_args(C, A, B)
    var"##vptr##_C" = LoopVectorization.stridedpointer(C)
    var"##vptr##_A" = LoopVectorization.stridedpointer(A)
    var"##vptr##_B" = LoopVectorization.stridedpointer(B)
    begin
        $(Expr(:gc_preserve, quote
            begin
                var"##Tloopeltype##" = promote_type(eltype(C), eltype(A), eltype(B))
                var"##Wvecwidth##" = LoopVectorization.pick_vector_width_val(eltype(C), eltype(A), eltype(B))
            end
            LoopVectorization._avx_!(...)
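The `if LoopVectorization.check_args(C, A, B)` guard in the expansion above follows a check-then-dispatch pattern: take the specialized path only when the arguments are supported, otherwise run an ordinary loop. A minimal sketch of that pattern in plain Julia follows; `supports_fast_path` and `dot_checked` are hypothetical names, and the actual conditions tested by `check_args` are not reproduced here.

```julia
# Hypothetical stand-in for an argument check: accept only dense arrays of hardware floats.
supports_fast_path(xs...) = all(x -> x isa DenseArray{<:Union{Float32,Float64}}, xs)

function dot_checked(a::AbstractVector, b::AbstractVector)
    length(a) == length(b) || throw(DimensionMismatch("a and b must have equal length"))
    s = zero(promote_type(eltype(a), eltype(b)))
    if supports_fast_path(a, b)
        # "Fast" path: here just a SIMD-annotated loop over dense storage
        @inbounds @simd for i in eachindex(a, b)
            s += a[i] * b[i]
        end
    else
        # Fallback path: generic iteration that works for any AbstractVector
        for (x, y) in zip(a, b)
            s += x * y
        end
    end
    return s
end
```

An AD rule attached at this level would need to cover both branches, which is one reason the surrounding boilerplate matters.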
> So we should be able to define a `ChainRules.rrule` definition for `_avx_!` that returns the appropriate pull backs.

That sounds right to me. We support arbitrary types, so we should be able to support whatever you need. But I don't know that you will even need that.

The only limitation is that we don't readily support calling back into the current AD, though you can hard-code it to always call into Zygote. Not sure if that is needed.

> we will probably want to be able to store temporaries between the forward and backwards passes generally.

The pullback is a closure, so if you have temporaries in the forward pass and you reuse them inside the pullback, then that will not allocate.
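A small sketch of what that looks like, using plain Julia and a hypothetical function `sumtanh_with_pullback` that follows the (value, pullback) convention without depending on any AD package:

```julia
# Hypothetical example: the pullback closes over arrays created in the forward pass.
function sumtanh_with_pullback(x::AbstractVector{<:AbstractFloat})
    t = tanh.(x)          # temporary computed in the forward pass
    dx = similar(t)       # gradient buffer, allocated once alongside it
    y = sum(t)
    function pullback(dy::Real)
        # d/dx tanh(x) = 1 - tanh(x)^2, so reuse `t` instead of calling tanh again
        @. dx = dy * (1 - t^2)
        return dx         # note: the same buffer is overwritten on every call
    end
    return y, pullback
end

y, back = sumtanh_with_pullback([0.0, 0.5])
back(1.0)
```

The trade-off is that the returned gradient aliases the captured buffer, so callers that keep results from several pullback calls would need to copy.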
This relied on internals and of course broke a long time ago, but anyone wanting to look at this later should try checking out https://github.com/JuliaSIMD/LoopVectorization.jl/tree/v0.8.21 (LoopVectorization 0.8.21), which was released on the same day I filed this issue, making it the most likely candidate to actually work.
@willtebbutt @oxinabox

I'd like suggestions/advice on how we can get autodiff support for `LoopVectorization.@avx`. Where / how can we integrate things to get things to work with the various AD tools?

In this issue, I'll briefly describe LoopVectorization's API, and what ReverseDiffExpressions.jl can currently do. One of the goals of this issue is to define what sort of an API/user interface it should expose. (Quick note in case anyone looks at the src: most of the code in here is not `include`-ed; it's old and will soon be deleted.)

`LoopVectorization.@avx for ...` works by calling a generated function `LoopVectorization._avx_!`, and passing it a description of the loops (this is behind an `if check_args(...)`; if `check_args` returns false, it'll use a fallback loop).

We can do this manually via:

(If we passed a quote as a second argument to `LoopVectorization.setup_call`, this would have been the fallback for an `if check_args`.)

We can create a `∂LoopSet` object that holds forward and reverse pass loop sets via:

Similarly, a reduction to a scalar
This example demonstrates an additional obstacle: management of temporary arrays. While I will probably make it simply recompute values rather than store temporaries when the computations are cheap (e.g., less than 3x the cost of a load + store), we will probably want to be able to store temporaries between the forward and backwards passes generally.
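To make the recompute-versus-store distinction concrete, here is a hand-written sketch of both choices for two scalar reductions. The function names are hypothetical, and the choice of which variant to use is made by hand here; the cost heuristic mentioned above is not implemented.

```julia
# Cheap elementwise op (a multiply): recompute the derivative 2x[i] in the reverse pass.
function sumsq_recompute(x::AbstractVector{<:Real})
    y = sum(abs2, x)
    pullback(dy) = (2 * dy) .* x     # nothing kept from the forward pass except x itself
    return y, pullback
end

# More expensive elementwise op (exp): store the forward values and reuse them.
function sumexp_store(x::AbstractVector{<:Real})
    tmp = exp.(x)                    # temporary array kept alive for the reverse pass
    y = sum(tmp)
    pullback(dy) = dy .* tmp         # d/dx exp(x) = exp(x), read back from tmp
    return y, pullback
end
```

The first variant trades a little arithmetic for one fewer array; the second trades memory for avoiding a second round of `exp` calls.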
Obligatory benchmark:
Currently, these benchmarks require the master branches of LoopVectorization, ReverseDiffExpressionsBase (unregistered), and ReverseDiffExpressions (unregistered).
I'll probably tag another release of LoopVectorization soon. The updates from the latest release are very minor.