asl opened this issue 1 year ago
Tagging @rxwei @BradLarson @jkshtj
Any ideas, etc. would be welcome.
Things to check immediately:

- Whether a function returning a closure (a `partial_apply`) gets an additional inlining bonus in the inlining cost / benefit calculation. VJPs return a tuple of the original result and a pullback closure: do they get this additional bonus? And even if yes – do we want to increase this bonus for VJPs?

And now the most interesting stuff wrt the code above. Here we're seeing that VJPs were fully inlined, and therefore we are seeing the very final `partial_apply`s that are further stored in the pullback tuple and passed to the pullback. What we can do here is to rewrite the linear map enum / linear map tuple: if we know that we are going to store a `partial_apply` there, then store the value it is closed over instead. Then we can move the `partial_apply` to the corresponding BB of the pullback, opening full inlining opportunities, as there is an `apply` of that `partial_apply` there.
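For a concrete picture of what such a closure closes over, here is a rough Swift-level model of `_vjpSin` (a sketch for illustration only; the real implementation lives in the `_Differentiation` module and is spelled differently). The pullback captures just `x`, and that single `Float` is exactly what we would store in the linear map tuple instead of the closure:

import Foundation

// A rough, illustrative model of _vjpSin -- not the actual stdlib source.
// The pullback closure captures only `x`; that captured Float is the
// entire context the partial_apply allocates.
func myVJPSin(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (sin(x), { v in v * cos(x) })
}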
I am seeing two cases here: closures created via `partial_apply`, which are handled as described above, and linear functions that are captured as-is, without any `partial_apply` (like the closure from `vjpAdd` above); the latter can be moved into the pullback directly, with no need to capture any value at all.
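Compare with addition: its pullback captures nothing, which is why it shows up above as a bare `thin_to_thick_function` rather than a `partial_apply` (again a rough sketch, not the actual stdlib source):

// A rough, illustrative model of Float._vjpAdd. The pullback captures
// no state, so in SIL it is a thin function converted via
// thin_to_thick_function, with an empty context.
func myVJPAdd(_ lhs: Float, _ rhs: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
    return (lhs + rhs, { v in (v, v) })
}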
Certainly, we'd need to handle the cases when not all VJPs are inlined. However, I believe this operation would be most beneficial for "leaf" pullbacks, as they are plentiful and usually represent elementary operations (like addition or multiplication) or custom-registered derivatives.
So, for the case above we'd turn:
// foo(_:)
sil hidden [noinline] @$s6sincos3fooyS2fF : $@convention(thin) (Float) -> Float {
[global: read,write,copy,destroy,allocate,deinit_barrier]
// %0 "x" // users: %26, %24, %16, %12, %3, %2, %1
bb0(%0 : $Float):
debug_value %0 : $Float, let, name "x", argno 1 // id: %1
debug_value %0 : $Float, let, name "x", argno 1 // id: %2
%3 = struct_extract %0 : $Float, #Float._value // users: %13, %9, %5
%4 = float_literal $Builtin.FPIEEE32, 0x0 // 0 // user: %5
%5 = builtin "fcmp_olt_FPIEEE32"(%4 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.Int1 // user: %7
%6 = tuple () // users: %22, %8
cond_br %5, bb1, bb2 // id: %7
bb1: // Preds: bb0
%8 = enum $_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %19
%9 = builtin "int_sin_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %10
%10 = struct $Float (%9 : $Builtin.FPIEEE32) // user: %18
// function_ref closure #1 in _vjpSin(_:)
%11 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %12
%12 = partial_apply [callee_guaranteed] %11(%0) : $@convention(thin) (Float, Float) -> Float // user: %19
%13 = builtin "int_cos_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %14
%14 = struct $Float (%13 : $Builtin.FPIEEE32) // user: %18
// function_ref closure #1 in _vjpCos(_:)
%15 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %16
%16 = partial_apply [callee_guaranteed] %15(%0) : $@convention(thin) (Float, Float) -> Float // user: %19
// function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
%17 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %18
%18 = partial_apply [callee_guaranteed] %17(%14, %10) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %19
%19 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%8, %12, %16, %18) // user: %20
%20 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %19 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %21
br bb3(%20 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %21
bb2: // Preds: bb0
%22 = enum $_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %29
// function_ref closure #1 in _vjpSin(_:)
%23 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %24
%24 = partial_apply [callee_guaranteed] %23(%0) : $@convention(thin) (Float, Float) -> Float // user: %29
// function_ref closure #1 in _vjpCos(_:)
%25 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %26
%26 = partial_apply [callee_guaranteed] %25(%0) : $@convention(thin) (Float, Float) -> Float // user: %29
// function_ref closure #1 in static Float._vjpAdd(lhs:rhs:)
%27 = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) // user: %28
%28 = thin_to_thick_function %27 : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) // user: %29
%29 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%22, %24, %26, %28) // user: %30
%30 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %29 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %31
br bb3(%30 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %31
// %32 // user: %37
bb3(%32 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0): // Preds: bb1 bb2
// function_ref pullback of f(_:)
%33 = function_ref @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %37
%34 = integer_literal $Builtin.Int64, 1 // user: %35
%35 = builtin "sitofp_Int64_FPIEEE32"(%34 : $Builtin.Int64) : $Builtin.FPIEEE32 // user: %36
%36 = struct $Float (%35 : $Builtin.FPIEEE32) // user: %37
%37 = apply %33(%36, %32) : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %38
return %37 : $Float // id: %38
} // end sil function '$s6sincos3fooyS2fF'
into:
enum _AD__$s6sincos1fyS2fF_bb0__Pred__src_0_wrt_0 {
}
enum _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0 {
case bb0(())
}
enum _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0 {
case bb0(())
}
enum _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0 {
case bb2((predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, Float, Float))
case bb1((predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, Float, Float, (Float, Float)))
}
// foo(_:)
sil hidden [noinline] @$s6sincos3fooyS2fF : $@convention(thin) (Float) -> Float {
[global: read,write,copy,destroy,allocate,deinit_barrier]
// %0 "x" // users: %26, %24, %16, %12, %3, %2, %1
bb0(%0 : $Float):
debug_value %0 : $Float, let, name "x", argno 1 // id: %1
debug_value %0 : $Float, let, name "x", argno 1 // id: %2
%3 = struct_extract %0 : $Float, #Float._value // users: %13, %9, %5
%4 = float_literal $Builtin.FPIEEE32, 0x0 // 0 // user: %5
%5 = builtin "fcmp_olt_FPIEEE32"(%4 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.Int1 // user: %7
%6 = tuple () // users: %22, %8
cond_br %5, bb1, bb2 // id: %7
bb1: // Preds: bb0
%8 = enum $_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %19
%9 = builtin "int_sin_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %10
%10 = struct $Float (%9 : $Builtin.FPIEEE32) // user: %newt
%13 = builtin "int_cos_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %14
%14 = struct $Float (%13 : $Builtin.FPIEEE32)
%newt = tuple $(Float, Float) (%14, %10)
%19 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, Float, Float, (Float, Float)) (%8, %0, %0, %newt) // user: %20
%20 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %19 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, Float, Float, (Float, Float)) // user: %21
br bb3(%20 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %21
bb2: // Preds: bb0
%22 = enum $_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %29
%29 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, Float, Float) (%22, %0, %0) // user: %30
%30 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %29 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, Float, Float) // user: %31
br bb3(%30 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %31
// %32 // user: %37
bb3(%32 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0): // Preds: bb1 bb2
// function_ref pullback of f(_:)
%33 = function_ref @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %37
%34 = integer_literal $Builtin.Int64, 1 // user: %35
%35 = builtin "sitofp_Int64_FPIEEE32"(%34 : $Builtin.Int64) : $Builtin.FPIEEE32 // user: %36
%36 = struct $Float (%35 : $Builtin.FPIEEE32) // user: %37
%37 = apply %33(%36, %32) : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %38
return %37 : $Float // id: %38
} // end sil function '$s6sincos3fooyS2fF'
Thoughts?
~~Based on discussions w/ Anton, there are 4 high-level tasks that we can look into (listed in order of priority):~~

~~1. Inlining cost model of functions receiving closures. What is it? What do we need to tweak so that simple pullbacks without any control flow can get fully inlined?~~
~~2. Inlining cost model of functions returning closures. What is it? Is it sufficient for optimizing most cases or does it need to be tweaked to unlock more optimizations? Based on the example here it seems like it does optimize VJPs (which return the pullback closure) correctly.~~
~~3. Reducing overall closure context allocations in VJPs: pullbacks with control flow, where the received closure arguments are hidden behind enums/tuples. Instead of passing in~~ ~~Reducing overall closure context allocations in pullbacks w/ loops might be similar but more tricky.~~
~~4. Interprocedural constant propagation to VJPs from differential operators.~~
Look below for the list of subtasks.
@jkshtj 3. is not really about inlining. And this particular issue is not about inlining either. The proposed optimization indeed enables some inlining opportunities, but again, it's not about inlining, but rather about changing the code layout – instead of capturing closures, capture the values being closed over.
The real subtasks are:

1. Check the inlining cost / benefit model for autodiff-generated functions (we need to ensure they receive the benefit bonus).
2. Implement the closure optimization that is specialized towards the linear map tuples / enums produced by autodiff. In particular, as in the example above, if we see a particular closure (a `partial_apply`), then instead of storing the closure in the tuple, store the captured value, and move the `partial_apply` down to the apply site (no need to fold it there; there are existing passes to do this). We know the place of use due to the way the linear map tuples and branch tracing enums are generated. See the sketch below.
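To make the intended layout change concrete, here is a Swift-level caricature of the `bb1` payload transformation above. All names here are made up, the real transformation operates on SIL, and it assumes the source computes sin(x) * cos(x) on that path:

import Foundation

enum BB1Pred { case bb0 }

// Today: the branch-tracing payload stores ready-made closures.
typealias PayloadBefore = (
    predecessor: BB1Pred,
    pbSin: (Float) -> Float,
    pbCos: (Float) -> Float,
    pbMul: (Float) -> (Float, Float)
)

// After the rewrite: the payload stores only the captured values --
// two Floats and a (Float, Float) pair, no closure contexts at all.
typealias PayloadAfter = (
    predecessor: BB1Pred,
    xForSin: Float,
    xForCos: Float,
    mulCaptures: (lhs: Float, rhs: Float)  // (sin(x), cos(x))
)

// The pullback then rebuilds the linear maps at the apply site, where
// they are trivially inlinable and foldable.
func pullbackBB1(_ seed: Float, _ p: PayloadAfter) -> Float {
    // Pullback of lhs * rhs: v -> (v * rhs, v * lhs).
    let (dLhs, dRhs) = (seed * p.mulCaptures.rhs, seed * p.mulCaptures.lhs)
    // Pullback of sin: v -> v * cos(x); pullback of cos: v -> -v * sin(x).
    return dLhs * cos(p.xForSin) + dRhs * -sin(p.xForCos)
}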
> One immediate thing is to improve constant propagation: obviously the condition in the function is already known to be true (though this folding currently happens further down, at the LLVM IR level; at the SIL level we still have 3 closures there and therefore 3 context allocations for them).
@asl I compiled a non-autodiff version of our example here, and it looks like the compiler does not eliminate conditional branches even if the value of the condition is known. The code below:
func f(_ x: Float) -> Float {
if (x > 0) {
return x*x
} else {
return x+x
}
}
@inline(never)
func foo() -> Float {
f(4)
}
Compiles to this -
// foo()
sil hidden [noinline] @$s4test3fooSfyF : $@convention(thin) () -> Float {
[global: ]
bb0:
%0 = float_literal $Builtin.FPIEEE32, 0x40800000 // 4 // users: %1, %3
debug_value %0 : $Builtin.FPIEEE32, let, name "x", argno 1, type $Float, expr op_fragment:#Float._value // id: %1
%2 = float_literal $Builtin.FPIEEE32, 0x0 // 0 // user: %3
%3 = builtin "fcmp_olt_FPIEEE32"(%2 : $Builtin.FPIEEE32, %0 : $Builtin.FPIEEE32) : $Builtin.Int1 // user: %4
cond_br %3, bb1, bb2 // id: %4
bb1: // Preds: bb0
%5 = float_literal $Builtin.FPIEEE32, 0x41800000 // 16 // user: %6
%6 = struct $Float (%5 : $Builtin.FPIEEE32) // user: %7
br bb3(%6 : $Float) // id: %7
bb2: // Preds: bb0
%8 = float_literal $Builtin.FPIEEE32, 0x41000000 // 8 // user: %9
%9 = struct $Float (%8 : $Builtin.FPIEEE32) // user: %10
br bb3(%9 : $Float) // id: %10
// %11 // user: %12
bb3(%11 : $Float): // Preds: bb2 bb1
return %11 : $Float // id: %12
} // end sil function '$s4test3fooSfyF'
I mostly looked through the SimplifyCFG optimization pass to see if something is just blocking this optimization, but it looks like it's not implemented at all.
Looks like the constant folder doesn't handle floating-point comparisons. Reproducible with:
func f() -> Bool {
let x: Float = 1.0
return x > 0
}
Even when compiled with -O, the generated SIL still contains the unfolded `fcmp_olt` builtin. It should be easy to add this to `constantFoldCompare`:
https://github.com/apple/swift/blob/3eea9d680d334841c8831fab87c1b2fe3add95d2/lib/SILOptimizer/Utils/ConstantFolding.cpp#L354
Is someone motivated to add this?
@eeckstein
Thanks for pointing that out! Just by looking at SimplifyCFG I could not understand how "insufficient" constant propagation was blocking the optimization, but dumping the SIL around the optimization passes shows that's what's happening. I'll read the code a bit more to see where exactly in SimplifyCFG this optimization is fully realized.
> Is someone motivated to add this?
Are you asking for more use-cases (in addition to the one in this example)? Or are you asking if someone would be motivated to do the work? In which case I can definitely go ahead and do it.
> In which case I can definitely go ahead and do it.
That would be great!
Recent autodiff changes that removed the use of pullback structs in favor of direct pullback application to nested calls enabled lots of inlining and specialization opportunities in VJPs / pullbacks.
However, the situation is not that great for functions with control flow, as the generated code with enums / tuples effectively prevents inlining and function substitution: the forward path becomes quite "opaque".
Consider the following small example:
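(The source below is a reconstruction from the SIL shown in this thread; the exact original example may differ slightly.)

import _Differentiation
import Foundation

@differentiable(reverse)
func f(_ x: Float) -> Float {
    if x > 0 {
        return sin(x) * cos(x)
    } else {
        return sin(x) + cos(x)
    }
}

@inline(never)
func foo(_ x: Float) -> Float {
    gradient(at: x, of: f)
}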
We're currently generating:
It would be great to find ways to simplify this further.
One immediate thing is to improve constant propagation: obviously the condition in the function is already known to be true (though this folding currently happens further down, at the LLVM IR level; at the SIL level we still have 3 closures there and therefore 3 context allocations for them).
Though, this certainly won't help the common case:
It looks like the `apply` of the pullback closure was correctly turned into an apply of the pullback itself. However, it was not inlined further. It would be interesting to understand why. And even if the pullback were inlined, can we do some kind of CSE and optimize the diamond-shaped CFG, hoisting the code into the corresponding BBs?