swiftlang / swift

The Swift Programming Language
https://swift.org
Apache License 2.0
67.65k stars 10.38k forks source link

Investigate possible optimization opportunities for autodiff code with control flow #68901

Open asl opened 1 year ago

asl commented 1 year ago

Recent autodiff changes that removed use of pullback structs in favor of direct pullback application to nested calls enabled lots of possible inlining and specialization opportunities in VJPs / pullbacks.

However, the situation is not as good for functions with control flow: the generated code with enums / tuples effectively prevents inlining and function substitution, as the forward path becomes quite "opaque".

Consider the following small example:

import _Differentiation
import Darwin

// Reproducer: a reverse-mode differentiable function with a two-way branch.
// The two arms end in different operations (multiply vs. add).
@differentiable(reverse)
func f(_ x: Float) -> Float {
  if (x > 0) {
    // Taken when x > 0: product of sin and cos.
    return sin(x) * cos(x)
  } else {
    // All other inputs: sum of sin and cos.
    return sin(x) + cos(x)
  }
}

// Computes the gradient of `f` at the constant 4.
// `@inline(never)` prevents this function from being inlined into callers.
@inline(never)
func foo() -> Float {
  // Pass the function `f` itself: `gradient(at:of:)` takes a function to
  // differentiate, and no `x` is in scope here, so `f(x)` would not compile.
  gradient(at: Float(4), of: f)
}

We're currently generating:

// foo()
sil hidden [noinline] @$s6sincos3fooSfyF : $@convention(thin) () -> Float {
[global: read,write,copy,destroy,allocate,deinit_barrier]
bb0:
  %0 = float_literal $Builtin.FPIEEE32, 0x40800000 // 4 // users: %4, %8, %12, %1
  %1 = struct $Float (%0 : $Builtin.FPIEEE32)     // users: %25, %23, %15, %11, %2
  debug_value %1 : $Float, let, name "x", argno 1 // id: %2
  %3 = float_literal $Builtin.FPIEEE32, 0x0 // 0  // user: %4
  %4 = builtin "fcmp_olt_FPIEEE32"(%3 : $Builtin.FPIEEE32, %0 : $Builtin.FPIEEE32) : $Builtin.Int1 // user: %6
  %5 = tuple ()                                   // users: %21, %7
  cond_br %4, bb1, bb2                            // id: %6

bb1:                                              // Preds: bb0
  %7 = enum $_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %5 : $() // user: %18
  %8 = builtin "int_sin_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %9
  %9 = struct $Float (%8 : $Builtin.FPIEEE32)     // user: %17
  // function_ref closure #1 in _vjpSin(_:)
  %10 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %11
  %11 = partial_apply [callee_guaranteed] %10(%1) : $@convention(thin) (Float, Float) -> Float // user: %18
  %12 = builtin "int_cos_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %13
  %13 = struct $Float (%12 : $Builtin.FPIEEE32)   // user: %17
  // function_ref closure #1 in _vjpCos(_:)
  %14 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %15
  %15 = partial_apply [callee_guaranteed] %14(%1) : $@convention(thin) (Float, Float) -> Float // user: %18
  // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
  %16 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %17
  %17 = partial_apply [callee_guaranteed] %16(%13, %9) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %18
  %18 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%7, %11, %15, %17) // user: %19
  %19 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %18 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %20
  br bb3(%19 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %20

bb2:                                              // Preds: bb0
  %21 = enum $_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %5 : $() // user: %28
  // function_ref closure #1 in _vjpSin(_:)
  %22 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %23
  %23 = partial_apply [callee_guaranteed] %22(%1) : $@convention(thin) (Float, Float) -> Float // user: %28
  // function_ref closure #1 in _vjpCos(_:)
  %24 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %25
  %25 = partial_apply [callee_guaranteed] %24(%1) : $@convention(thin) (Float, Float) -> Float // user: %28
  // function_ref closure #1 in static Float._vjpAdd(lhs:rhs:)
  %26 = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) // user: %27
  %27 = thin_to_thick_function %26 : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) // user: %28
  %28 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%21, %23, %25, %27) // user: %29
  %29 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %28 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %30
  br bb3(%29 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %30

// %31                                            // user: %36
bb3(%31 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0): // Preds: bb1 bb2
  // function_ref pullback of f(_:)
  %32 = function_ref @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %36
  %33 = integer_literal $Builtin.Int64, 1         // user: %34
  %34 = builtin "sitofp_Int64_FPIEEE32"(%33 : $Builtin.Int64) : $Builtin.FPIEEE32 // user: %35
  %35 = struct $Float (%34 : $Builtin.FPIEEE32)   // user: %36
  %36 = apply %32(%35, %31) : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %37
  return %36 : $Float                             // id: %37
} // end sil function '$s6sincos3fooSfyF'

// pullback of f(_:)
sil private @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float {
[%0: read v**.c*.v**, write v**.c*.v**, copy v**.c*.v**, destroy v**.c*.v**]
[%1: noescape v**, read v**.c*.v**, write v**.c*.v**, copy v**.c*.v**, destroy v**.c*.v**]
[global: read,write,copy,destroy,allocate,deinit_barrier]
// %0                                             // users: %35, %10
// %1                                             // user: %5
bb0(%0 : $Float, %1 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0):
  %2 = integer_literal $Builtin.Int64, 0          // user: %3
  %3 = builtin "sitofp_Int64_FPIEEE32"(%2 : $Builtin.Int64) : $Builtin.FPIEEE32 // users: %43, %40, %18, %15, %4, %48, %23
  debug_value %3 : $Builtin.FPIEEE32, let, name "x", argno 1, type $Float, expr op_fragment:#Float._value // id: %4
  switch_enum %1 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, case #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt: bb1, case #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt: bb2 // id: %5

// %6                                             // users: %9, %8, %7
bb1(%6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float))): // Preds: bb0
  %7 = tuple_extract %6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 1 // users: %25, %24
  %8 = tuple_extract %6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 2 // users: %21, %20
  %9 = tuple_extract %6 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 3 // users: %11, %10
  %10 = apply %9(%0) : $@callee_guaranteed (Float) -> (Float, Float) // users: %13, %12
  strong_release %9 : $@callee_guaranteed (Float) -> (Float, Float) // id: %11
  %12 = tuple_extract %10 : $(Float, Float), 0    // user: %14
  %13 = tuple_extract %10 : $(Float, Float), 1    // user: %17
  %14 = struct_extract %12 : $Float, #Float._value // user: %15
  %15 = builtin "fadd_FPIEEE32"(%14 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %16
  %16 = struct $Float (%15 : $Builtin.FPIEEE32)   // user: %24
  %17 = struct_extract %13 : $Float, #Float._value // user: %18
  %18 = builtin "fadd_FPIEEE32"(%17 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %19
  %19 = struct $Float (%18 : $Builtin.FPIEEE32)   // user: %20
  %20 = apply %8(%19) : $@callee_guaranteed (Float) -> Float // user: %22
  strong_release %8 : $@callee_guaranteed (Float) -> Float // id: %21
  %22 = struct_extract %20 : $Float, #Float._value // user: %23
  %23 = builtin "fadd_FPIEEE32"(%22 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %27
  %24 = apply %7(%16) : $@callee_guaranteed (Float) -> Float // user: %26
  strong_release %7 : $@callee_guaranteed (Float) -> Float // id: %25
  %26 = struct_extract %24 : $Float, #Float._value // user: %27
  %27 = builtin "fadd_FPIEEE32"(%26 : $Builtin.FPIEEE32, %23 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %28
  %28 = struct $Float (%27 : $Builtin.FPIEEE32)   // users: %30, %29
  debug_value %28 : $Float, let, name "x", argno 1 // id: %29
  br bb3(%28 : $Float)                            // id: %30

// %31                                            // users: %34, %33, %32
bb2(%31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float))): // Preds: bb0
  %32 = tuple_extract %31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 1 // users: %50, %49
  %33 = tuple_extract %31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 2 // users: %46, %45
  %34 = tuple_extract %31 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)), 3 // users: %36, %35
  %35 = apply %34(%0) : $@callee_guaranteed (Float) -> (Float, Float) // users: %38, %37
  strong_release %34 : $@callee_guaranteed (Float) -> (Float, Float) // id: %36
  %37 = tuple_extract %35 : $(Float, Float), 0    // user: %39
  %38 = tuple_extract %35 : $(Float, Float), 1    // user: %42
  %39 = struct_extract %37 : $Float, #Float._value // user: %40
  %40 = builtin "fadd_FPIEEE32"(%39 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %41
  %41 = struct $Float (%40 : $Builtin.FPIEEE32)   // user: %49
  %42 = struct_extract %38 : $Float, #Float._value // user: %43
  %43 = builtin "fadd_FPIEEE32"(%42 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %44
  %44 = struct $Float (%43 : $Builtin.FPIEEE32)   // user: %45
  %45 = apply %33(%44) : $@callee_guaranteed (Float) -> Float // user: %47
  strong_release %33 : $@callee_guaranteed (Float) -> Float // id: %46
  %47 = struct_extract %45 : $Float, #Float._value // user: %48
  %48 = builtin "fadd_FPIEEE32"(%47 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %52
  %49 = apply %32(%41) : $@callee_guaranteed (Float) -> Float // user: %51
  strong_release %32 : $@callee_guaranteed (Float) -> Float // id: %50
  %51 = struct_extract %49 : $Float, #Float._value // user: %52
  %52 = builtin "fadd_FPIEEE32"(%51 : $Builtin.FPIEEE32, %48 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %53
  %53 = struct $Float (%52 : $Builtin.FPIEEE32)   // users: %55, %54
  debug_value %53 : $Float, let, name "x", argno 1 // id: %54
  br bb3(%53 : $Float)                            // id: %55

// %56                                            // users: %58, %57
bb3(%56 : $Float):                                // Preds: bb1 bb2
  debug_value %56 : $Float, let, name "x", argno 1 // id: %57
  return %56 : $Float                             // id: %58
} // end sil function '$s6sincos1fyS2fFTJpSpSr'

It would be great to find ways to simplify this further.

One immediate thing is to improve constant propagation: the condition in the function is obviously known to be true at compile time. (This folding does happen later, at the LLVM IR level; however, we still have 3 closures there and therefore 3 context allocations.)

Though, this certainly won't help the common case:

// General case: the point of differentiation is a runtime argument rather
// than a literal.
@inline(never)
func foo(_ x : Float) -> Float {
  gradient(at: x, of: f)
}

It looks like the apply of the pullback closure was correctly turned into an apply of the pullback itself. However, it was not inlined further. It would be interesting to understand why. And even if the pullback were inlined, could we do some kind of CSE and optimize the diamond-shaped CFG, hoisting the code into the corresponding BBs?

asl commented 1 year ago

Tagging @rxwei @BradLarson @jkshtj

Any ideas, etc. would be welcome

asl commented 1 year ago

Things to check immediately:

  1. Usually a function that returns a closure (partial_apply) gets an additional inlining bonus in the inlining cost / benefit calculation. VJPs return a tuple of the original result and a pullback closure: do they get this additional bonus? And even if yes – do we want to increase this bonus for VJPs?
  2. A similar thing applies to pullbacks – IIRC there is an additional inlining cost / benefit bonus if the function takes a closure as an argument. However, in the case of control flow differentiation, these closures are hidden beneath the enum payload. Do we handle this case correctly? Do pullbacks receive this bonus? Do we want to increase the inlining bonus for pullbacks?

And now the most interesting stuff wrt the code above. Here we see that the VJPs were fully inlined, so what remains are the final partial_apply instructions, which are then stored in the pullback tuple and passed to the pullback. What we can do here is rewrite the linear map enum / linear map tuple: if we know that we are going to store a partial_apply there, then store the value it closes over instead. Then we can move the partial_apply to the corresponding BB of the pullback, opening up full inlining opportunities, as there is an apply of that partial_apply there.

I am seeing two cases here:

  1. Pullback is fully inlined. Fine, we're operating inside a single function.
  2. Pullback is not fully inlined. Then we need to "specialize" the pullback: clone the pullback itself, change its signature, and change the corresponding code.

Additionally, we might have linear functions that are captured as-is, without partial_apply (like closure from vjpAdd above), then we can move them into pullback directly, no need to capture the value at all.

Certainly, we'd need to handle the cases when not all VJPs are inlined. However, I believe this operation would be most beneficial for "leaf" pullbacks, as they are plenty, usually represent elementary operations (like addition, multiplication, etc.) or custom-registered derivatives.

So, for the case above we'd turn:

// foo(_:)
sil hidden [noinline] @$s6sincos3fooyS2fF : $@convention(thin) (Float) -> Float {
[global: read,write,copy,destroy,allocate,deinit_barrier]
// %0 "x"                                         // users: %26, %24, %16, %12, %3, %2, %1
bb0(%0 : $Float):
  debug_value %0 : $Float, let, name "x", argno 1 // id: %1
  debug_value %0 : $Float, let, name "x", argno 1 // id: %2
  %3 = struct_extract %0 : $Float, #Float._value  // users: %13, %9, %5
  %4 = float_literal $Builtin.FPIEEE32, 0x0 // 0  // user: %5
  %5 = builtin "fcmp_olt_FPIEEE32"(%4 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.Int1 // user: %7
  %6 = tuple ()                                   // users: %22, %8
  cond_br %5, bb1, bb2                            // id: %7

bb1:                                              // Preds: bb0
  %8 = enum $_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %19
  %9 = builtin "int_sin_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %10
  %10 = struct $Float (%9 : $Builtin.FPIEEE32)    // user: %18
  // function_ref closure #1 in _vjpSin(_:)
  %11 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %12
  %12 = partial_apply [callee_guaranteed] %11(%0) : $@convention(thin) (Float, Float) -> Float // user: %19
  %13 = builtin "int_cos_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %14
  %14 = struct $Float (%13 : $Builtin.FPIEEE32)   // user: %18
  // function_ref closure #1 in _vjpCos(_:)
  %15 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %16
  %16 = partial_apply [callee_guaranteed] %15(%0) : $@convention(thin) (Float, Float) -> Float // user: %19
  // function_ref closure #1 in static Float._vjpMultiply(lhs:rhs:)
  %17 = function_ref @$sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %18
  %18 = partial_apply [callee_guaranteed] %17(%14, %10) : $@convention(thin) (Float, Float, Float) -> (Float, Float) // user: %19
  %19 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%8, %12, %16, %18) // user: %20
  %20 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %19 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %21
  br bb3(%20 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %21

bb2:                                              // Preds: bb0
  %22 = enum $_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %29
  // function_ref closure #1 in _vjpSin(_:)
  %23 = function_ref @$s16_Differentiation7_vjpSinySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %24
  %24 = partial_apply [callee_guaranteed] %23(%0) : $@convention(thin) (Float, Float) -> Float // user: %29
  // function_ref closure #1 in _vjpCos(_:)
  %25 = function_ref @$s16_Differentiation7_vjpCosySf5value_S2fc8pullbacktSfFS2fcfU_ : $@convention(thin) (Float, Float) -> Float // user: %26
  %26 = partial_apply [callee_guaranteed] %25(%0) : $@convention(thin) (Float, Float) -> Float // user: %29
  // function_ref closure #1 in static Float._vjpAdd(lhs:rhs:)
  %27 = function_ref @$sSf16_DifferentiationE7_vjpAdd3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_ : $@convention(thin) (Float) -> (Float, Float) // user: %28
  %28 = thin_to_thick_function %27 : $@convention(thin) (Float) -> (Float, Float) to $@callee_guaranteed (Float) -> (Float, Float) // user: %29
  %29 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) (%22, %24, %26, %28) // user: %30
  %30 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %29 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> Float, @callee_guaranteed (Float) -> (Float, Float)) // user: %31
  br bb3(%30 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %31

// %32                                            // user: %37
bb3(%32 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0): // Preds: bb1 bb2
  // function_ref pullback of f(_:)
  %33 = function_ref @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %37
  %34 = integer_literal $Builtin.Int64, 1         // user: %35
  %35 = builtin "sitofp_Int64_FPIEEE32"(%34 : $Builtin.Int64) : $Builtin.FPIEEE32 // user: %36
  %36 = struct $Float (%35 : $Builtin.FPIEEE32)   // user: %37
  %37 = apply %33(%36, %32) : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %38
  return %37 : $Float                             // id: %38
} // end sil function '$s6sincos3fooyS2fF'

into:


// Branch-tracing enums for f(_:) after the proposed rewrite: payloads hold the
// values that the pullback closures captured instead of the closures themselves.

// bb0 has no predecessors, so this enum has no cases.
enum _AD__$s6sincos1fyS2fF_bb0__Pred__src_0_wrt_0 {
}

enum _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0 {
  case bb0(())
}

enum _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0 {
  case bb0(())
}

enum _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0 {
  // Arriving from bb2 (sin + cos): the value captured by the sin pullback and
  // the value captured by the cos pullback. The add pullback captures nothing.
  case bb2((predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, Float, Float))
  // Arriving from bb1 (sin * cos): the same two captured values, plus the
  // (cos, sin) pair captured by the multiply pullback.
  case bb1((predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, Float, Float, (Float, Float)))
}

// foo(_:)
// Proposed form: the linear-map tuples built in bb1 / bb2 store the values that
// the partial_apply closures used to capture, instead of the closures.
sil hidden [noinline] @$s6sincos3fooyS2fF : $@convention(thin) (Float) -> Float {
[global: read,write,copy,destroy,allocate,deinit_barrier]
// %0 "x"                                         // users: %29, %19, %3, %2, %1
bb0(%0 : $Float):
  debug_value %0 : $Float, let, name "x", argno 1 // id: %1
  debug_value %0 : $Float, let, name "x", argno 1 // id: %2
  %3 = struct_extract %0 : $Float, #Float._value  // users: %13, %9, %5
  %4 = float_literal $Builtin.FPIEEE32, 0x0 // 0  // user: %5
  %5 = builtin "fcmp_olt_FPIEEE32"(%4 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.Int1 // user: %7
  %6 = tuple ()                                   // users: %22, %8
  cond_br %5, bb1, bb2                            // id: %7

bb1:                                              // Preds: bb0
  %8 = enum $_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %19
  %9 = builtin "int_sin_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %10
  %10 = struct $Float (%9 : $Builtin.FPIEEE32)    // user: %newt
  %13 = builtin "int_cos_FPIEEE32"(%3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %14
  %14 = struct $Float (%13 : $Builtin.FPIEEE32)   // user: %newt
  // (cos, sin) pair that the multiply pullback closure previously captured.
  %newt = tuple $(Float, Float) (%14, %10)        // user: %19
  // %0 appears twice: once for the sin pullback and once for the cos pullback.
  %19 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, Float, Float, (Float, Float)) (%8, %0, %0, %newt) // user: %20
  %20 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb1!enumelt, %19 : $(predecessor: _AD__$s6sincos1fyS2fF_bb1__Pred__src_0_wrt_0, Float, Float, (Float, Float)) // user: %21
  br bb3(%20 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %21

bb2:                                              // Preds: bb0
  %22 = enum $_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0.bb0!enumelt, %6 : $() // user: %29
  // Only the sin / cos captures are stored; the add pullback captures nothing.
  %29 = tuple $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, Float, Float) (%22, %0, %0) // user: %30
  %30 = enum $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0, #_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0.bb2!enumelt, %29 : $(predecessor: _AD__$s6sincos1fyS2fF_bb2__Pred__src_0_wrt_0, Float, Float) // user: %31
  br bb3(%30 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) // id: %31

// %32                                            // user: %37
bb3(%32 : $_AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0): // Preds: bb1 bb2
  // function_ref pullback of f(_:)
  %33 = function_ref @$s6sincos1fyS2fFTJpSpSr : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %37
  %34 = integer_literal $Builtin.Int64, 1         // user: %35
  %35 = builtin "sitofp_Int64_FPIEEE32"(%34 : $Builtin.Int64) : $Builtin.FPIEEE32 // user: %36
  %36 = struct $Float (%35 : $Builtin.FPIEEE32)   // user: %37
  %37 = apply %33(%36, %32) : $@convention(thin) (Float, @owned _AD__$s6sincos1fyS2fF_bb3__Pred__src_0_wrt_0) -> Float // user: %38
  return %37 : $Float                             // id: %38
} // end sil function '$s6sincos3fooyS2fF'

Thoughts?

jkshtj commented 1 year ago

~~Based on discussions w/ Anton, there are 4 high-level tasks that we can look into (listed in order of priority):~~

~~1. Inlining cost model of functions receiving closures.~~
~~   - What is it? What do we need to tweak so that simple pullbacks without any control flow can get fully inlined?~~

~~2. Inlining cost model of functions returning closures.~~
~~   - What is it? Is it sufficient for optimizing most cases or does it need to be tweaked to unlock more optimizations?~~
~~   - Based on the example here it seems like it does optimize VJPs (which return the pullback closure) correctly.~~

~~3. Reducing overall closure context allocations in VJPs.~~
~~   - Pullbacks with control flow, where the received closure arguments are hidden behind enums/tuples.~~
~~   - Instead of passing in…~~
~~   - Reducing overall closure context allocations in pullbacks w/ loops might be similar but more tricky.~~

~~4. Interprocedural constant propagation to VJPs from differential operators.~~

Look below for list of subtasks.

asl commented 1 year ago

@jkshtj 3. is not about inlining really. And this particular issue is not about inlining. The proposed optimization indeed enables some inlining opportunities, but again, it's not about inlining, but rather changing of the code layout – instead of capturing closures capture values being closed.

asl commented 1 year ago

The real subtasks are:

  1. Check the inlining cost / benefit model for autodiff-generated functions (we need to ensure they receive benefit bonus):

    • VJPs (return tuple of value + pullback closure)
    • Pullbacks (may receive closures as arguments or may receive linear map tuple / enum in case of control flow differentiation). The linear map tuple case is likely not properly handled.
  2. Implement the closure optimization that is specialized towards the linear map tuples / enums produced by autodiff. In particular, as in the example above, if we see a particular closure (partial_apply), then instead of storing the closure in the tuple, store the closed-over value, and move the partial_apply down to the apply site (no need to fold; there are existing passes to do this). We know the place of use because of the way the linear map tuples and branch tracing enums are generated.

jkshtj commented 1 year ago

One immediate thing is to improve constant propagation: the condition in the function is obviously known to be true at compile time. (This folding does happen later, at the LLVM IR level; however, we still have 3 closures there and therefore 3 context allocations.)

@asl I compiled non-autodiff version of our example here and it looks like the compiler does not eliminate conditional branches even if the value of the condition is known. The code below -

// Non-autodiff variant of the reproducer: the same branch on x > 0, but with
// plain arithmetic in each arm.
func f(_ x: Float) -> Float {
  if (x > 0) {
    // Taken when x > 0: the square of x.
    return x*x
  } else {
    // All other inputs: x doubled.
    return x+x
  }
}

// Calls `f` with the literal 4, so the branch condition inside `f` is known
// at compile time.
@inline(never)
func foo() -> Float {
  f(4)
}

Compiles to this -

// foo()
sil hidden [noinline] @$s4test3fooSfyF : $@convention(thin) () -> Float {
[global: ]
bb0:
  %0 = float_literal $Builtin.FPIEEE32, 0x40800000 // 4 // users: %1, %3
  debug_value %0 : $Builtin.FPIEEE32, let, name "x", argno 1, type $Float, expr op_fragment:#Float._value // id: %1
  %2 = float_literal $Builtin.FPIEEE32, 0x0 // 0  // user: %3
  %3 = builtin "fcmp_olt_FPIEEE32"(%2 : $Builtin.FPIEEE32, %0 : $Builtin.FPIEEE32) : $Builtin.Int1 // user: %4
  cond_br %3, bb1, bb2                            // id: %4

bb1:                                              // Preds: bb0
  %5 = float_literal $Builtin.FPIEEE32, 0x41800000 // 16 // user: %6
  %6 = struct $Float (%5 : $Builtin.FPIEEE32)     // user: %7
  br bb3(%6 : $Float)                             // id: %7

bb2:                                              // Preds: bb0
  %8 = float_literal $Builtin.FPIEEE32, 0x41000000 // 8 // user: %9
  %9 = struct $Float (%8 : $Builtin.FPIEEE32)     // user: %10
  br bb3(%9 : $Float)                             // id: %10

// %11                                            // user: %12
bb3(%11 : $Float):                                // Preds: bb2 bb1
  return %11 : $Float                             // id: %12
} // end sil function '$s4test3fooSfyF'

I mostly looked through the SimplifyCfg optimization pass to see if something is just blocking this optimization, but it looks like it's not implemented at all.

eeckstein commented 1 year ago

It looks like the constant folder doesn't handle floating-point comparisons. Reproducible with

// Minimal reproducer: compares a constant Float against 0, so the result is
// known at compile time.
func f() -> Bool {
  let x: Float = 1.0
  return x > 0
}

Even when compiled with -O, the generated SIL still contains the unfolded fcmp_olt builtin.

It should be easy to add this to constantFoldCompare https://github.com/apple/swift/blob/3eea9d680d334841c8831fab87c1b2fe3add95d2/lib/SILOptimizer/Utils/ConstantFolding.cpp#L354

Is someone motivated to add this?

jkshtj commented 1 year ago

@eeckstein

Thanks for pointing that out! Just by looking at SimplifyCFG I could not understand how "insufficient" constant propagation was blocking the optimization, but dumping the SIL around the optimization passes shows that's what's happening. I'll read the code a bit more to see where exactly in SimplifyCFG this optimization is fully realized.

Is someone motivated to add this?

Are you asking for more use-cases (in addition to the one in this example)? Or are you asking if someone would be motivated to do the work? In which case I can definitely go ahead and do it.

eeckstein commented 1 year ago

In which case I can definitely go ahead and do it.

That would be great!