swiftlang / swift

The Swift Programming Language
https://swift.org
Apache License 2.0
67.6k stars 10.37k forks source link

ClosureLifetimeFixup inhibits optimizations #77651

Open asl opened 4 hours ago

asl commented 4 hours ago

Description

I saw this with autodiff-produced code, but I believe the issue is generic and not autodiff-related; it was just easier to notice there.

Consider the reproduction code below. The `hardyCrossSimpleLoop` function contains a loop that computes a gradient on every iteration. If we comment out the `let (f, Df) = valueWithGradient(at: flow, of: subsystem.totalDeltaP)` line and instead uncomment the following line, which does essentially the same thing, the code is optimized down to a simple sequence of arithmetic operations. The original code, however, cannot be optimized this way, and the runtimes of the two variants differ by roughly three orders of magnitude.

The loop body looks as follows:

bb2(%60 : $Int):                                  // Preds: bb1
  %61 = alloc_stack $Float                        // users: %109, %102, %93
  %62 = alloc_stack $Float                        // users: %108, %103, %93
  %63 = begin_access [read] [unknown] %24 : $*Float // users: %65, %64
  %64 = load [trivial] %63 : $*Float              // user: %67
  end_access %63 : $*Float                        // id: %65
  %66 = alloc_stack $Float                        // users: %101, %93, %67
  store %64 to [trivial] %66 : $*Float            // id: %67
  // function_ref implicit closure #1 in hardyCrossSimpleLoop(initialFlow:subsystem:)
  %68 = function_ref @$s4flow20hardyCrossSimpleLoop33_D8A55993C73D177076A5BC68755B3D90LL11initialFlow9subsystemS2f_AA9SubsystemACLLVtFS2fcAGcfu_ : $@convention(thin) (Subsystem) -> @owned @callee_guaranteed (Float) -> Float // user: %69
  %69 = apply %68(%1) : $@convention(thin) (Subsystem) -> @owned @callee_guaranteed (Float) -> Float // user: %70
  %70 = differentiable_function [parameters 0] [results 0] %69 : $@callee_guaranteed (Float) -> Float // users: %100, %71
  %71 = convert_escape_to_noescape [not_guaranteed] %70 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float to $@differentiable(reverse) @noescape @callee_guaranteed (Float) -> Float // users: %99, %72
  %72 = begin_borrow %71 : $@differentiable(reverse) @noescape @callee_guaranteed (Float) -> Float // users: %98, %85, %79, %73
  %73 = differentiable_function_extract [original] %72 : $@differentiable(reverse) @noescape @callee_guaranteed (Float) -> Float // user: %74
  %74 = copy_value %73 : $@noescape @callee_guaranteed (Float) -> Float // user: %76
  // function_ref thunk for @callee_guaranteed (@unowned Float) -> (@unowned Float)
  %75 = function_ref @$sS2fIgyd_S2fIegnr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @noescape @callee_guaranteed (Float) -> Float) -> @out Float // user: %76
  %76 = partial_apply [callee_guaranteed] %75(%74) : $@convention(thin) (@in_guaranteed Float, @guaranteed @noescape @callee_guaranteed (Float) -> Float) -> @out Float // user: %77
  %77 = convert_function %76 : $@callee_guaranteed (@in_guaranteed Float) -> @out Float to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float> // users: %97, %78
  %78 = convert_escape_to_noescape [not_guaranteed] %77 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float> to $@noescape @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float> // user: %91
  %79 = differentiable_function_extract [jvp] %72 : $@differentiable(reverse) @noescape @callee_guaranteed (Float) -> Float // user: %80
  %80 = copy_value %79 : $@noescape @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %82
  // function_ref thunk for @callee_guaranteed (@unowned Float) -> (@unowned Float, @owned @escaping @callee_guaranteed (@unowned Float) -> (@unowned Float))
  %81 = function_ref @$sS4fIegyd_Igydo_S2fxq_Ri_zRi0_zRi__Ri0__r0_lyS2fIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @noescape @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) // user: %82
  %82 = partial_apply [callee_guaranteed] %81(%80) : $@convention(thin) (@in_guaranteed Float, @guaranteed @noescape @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) // user: %83
  %83 = convert_function %82 : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float> // users: %96, %84
...

So, essentially, we form a closure over the `subsystem` value and then execute the autodiff code. Note that all values involved are explicitly destroyed at the end of the loop body. Differentiation does not change much here: essentially, the first apply is simply replaced by a call to the autodiff curry thunk.

What happens is that ClosureLifetimeFixup creates a bunch of optionals that extend the lifetime of all the closures into the next loop iteration:

bb0(%0 : $Float, %1 : $Subsystem):
  %2 = enum $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>>, #Optional.none!enumelt // users: %55, %3
  %3 = begin_borrow %2 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>> // user: %55
  %4 = enum $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>>, #Optional.none!enumelt // users: %55, %5
  %5 = begin_borrow %4 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>> // user: %55
  %6 = enum $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>>, #Optional.none!enumelt // users: %55, %7
  %7 = begin_borrow %6 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>> // user: %55
  %8 = enum $Optional<@differentiable(reverse) @callee_guaranteed (Float) -> Float>, #Optional.none!enumelt // users: %55, %9
  %9 = begin_borrow %8 : $Optional<@differentiable(reverse) @callee_guaranteed (Float) -> Float> // user: %55
  debug_value %0 : $Float, let, name "initialFlow", argno 1 // id: %10
...
bb1(%56 : @reborrow $Optional<@differentiable(reverse) @callee_guaranteed (Float) -> Float>, %57 : @owned $Optional<@differentiable(reverse) @callee_guaranteed (Float) -> Float>, %58 : @reborrow $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>>, %59 : @owned $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>>, %60 : @reborrow $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>>, %61 : @owned $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>>, %62 : @reborrow $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>>, %63 : @owned $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>>): // Preds: bb4 bb0
  %64 = borrowed %62 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>> from (%63 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>>) // users: %215, %127
  %65 = borrowed %60 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>> from (%61 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>>) // users: %215, %115
  %66 = borrowed %58 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>> from (%59 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>>) // users: %215, %103
  %67 = borrowed %56 : $Optional<@differentiable(reverse) @callee_guaranteed (Float) -> Float> from (%57 : $Optional<@differentiable(reverse) @callee_guaranteed (Float) -> Float>) // users: %215, %90
  %68 = alloc_stack $Optional<Int>                // users: %74, %73, %71
  %69 = begin_access [modify] [static] %32 : $*IndexingIterator<Range<Int>> // users: %72, %71
  // function_ref IndexingIterator.next()
  %70 = function_ref @$ss16IndexingIteratorV4next7ElementQzSgyF : $@convention(method) <τ_0_0 where τ_0_0 : Collection> (@inout IndexingIterator<τ_0_0>) -> @out Optional<τ_0_0.Element> // user: %71
...

It looks like these optionals inhibit all kinds of optimizations, including inlining and specialization. The closures and differentiable functions now have multiple uses, so the necessary peephole optimizations (which, naturally, expect a single use) cannot fire. As a result, after all optimizations we end up with:

// %34                                            // user: %54
// %35                                            // user: %59
// %36                                            // user: %62
// %37                                            // user: %66
// %38                                            // users: %89, %69
// %39                                            // users: %45, %41, %40
bb2(%34 : $Optional<@differentiable(reverse) @callee_guaranteed (Float) -> Float>, %35 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>>, %36 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>>, %37 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>>, %38 : $Builtin.FPIEEE32, %39 : $Builtin.Int64): // Preds: bb5 bb0
  %40 = builtin "cmp_slt_Int64"(%39 : $Builtin.Int64, %14 : $Builtin.Int64) : $Builtin.Int1 // user: %43
  %41 = builtin "cmp_slt_Int64"(%39 : $Builtin.Int64, %7 : $Builtin.Int64) : $Builtin.Int1 // user: %42
  %42 = builtin "xor_Int1"(%41 : $Builtin.Int1, %18 : $Builtin.Int1) : $Builtin.Int1 // user: %43
  %43 = builtin "or_Int1"(%40 : $Builtin.Int1, %42 : $Builtin.Int1) : $Builtin.Int1 // user: %44
  cond_fail %43 : $Builtin.Int1, "Index out of bounds" // id: %44
  %45 = builtin "sadd_with_overflow_Int64"(%39 : $Builtin.Int64, %19 : $Builtin.Int64, %18 : $Builtin.Int1) : $(Builtin.Int64, Builtin.Int1) // users: %47, %46
  %46 = tuple_extract %45 : $(Builtin.Int64, Builtin.Int1), 0 // users: %105, %97, %49
  %47 = tuple_extract %45 : $(Builtin.Int64, Builtin.Int1), 1 // user: %48
  cond_fail %47 : $Builtin.Int1, "arithmetic overflow" // id: %48
  debug_value %46 : $Builtin.Int64, var, name "$generator", type $IndexingIterator<Range<Int>>, expr op_fragment:#IndexingIterator._position:op_fragment:#Int._value // id: %49
  %50 = alloc_stack $Float                        // users: %80, %83, %85
  %51 = partial_apply [callee_guaranteed] %22(%1) : $@convention(thin) (Float, Subsystem) -> Float // users: %106, %109, %31, %53
  %52 = partial_apply [callee_guaranteed] %23(%1) : $@convention(thin) (Float, Subsystem) -> (Float, @owned @callee_guaranteed (Float) -> Float) // users: %107, %110, %32, %53
  %53 = differentiable_function [parameters 0] [results 0] %51 : $@callee_guaranteed (Float) -> Float with_derivative {%29 : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %52 : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} // users: %56, %55, %99
  release_value %34 : $Optional<@differentiable(reverse) @callee_guaranteed (Float) -> Float> // id: %54
  %55 = convert_escape_to_noescape %53 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float to $@differentiable(reverse) @noescape @callee_guaranteed (Float) -> Float // user: %56
  %56 = mark_dependence %55 : $@differentiable(reverse) @noescape @callee_guaranteed (Float) -> Float on %53 : $@differentiable(reverse) @callee_guaranteed (Float) -> Float // users: %57, %60, %63
  %57 = differentiable_function_extract [original] %56 : $@differentiable(reverse) @noescape @callee_guaranteed (Float) -> Float // user: %58
  %58 = partial_apply [callee_guaranteed] %24(%57) : $@convention(thin) (@in_guaranteed Float, @guaranteed @noescape @callee_guaranteed (Float) -> Float) -> @out Float // users: %112, %100
  release_value %35 : $Optional<@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>> // id: %59
  %60 = differentiable_function_extract [jvp] %56 : $@differentiable(reverse) @noescape @callee_guaranteed (Float) -> Float // user: %61
  %61 = partial_apply [callee_guaranteed] %25(%60) : $@convention(thin) (@in_guaranteed Float, @guaranteed @noescape @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) // users: %113, %102
...

(Note that %53 also has a use at %99, which essentially stores the function into an optional until the next iteration.)

None of this happens when we control the lifetime by simply outlining the code into a separate function.

Is there something that could be done differently on the autodiff side to "help" ClosureLifetimeFixup, in case the pass itself cannot be changed to handle this particular case better?

Reproduction

import _Differentiation

// MARK: Boiler

// Numeric type used for every quantity in the component models and the solver.
private typealias Scalar = Float

private struct Boiler {
    /// Flow-resistance coefficient of the boiler.
    ///
    /// Modelled as a constant: the flow argument is accepted only so the
    /// signature matches the other components, and the result ignores it.
    @differentiable(reverse)
    func inlineResistanceFluid(_ x: Scalar) -> Scalar {
        return 100000.0
    }

    /// Pressure change across the boiler at flow `x`,
    /// computed as `-inlineResistanceFluid(x) * x * x`.
    @differentiable(reverse)
    func deltaP(_ x: Scalar) -> Scalar {
        let resistance = inlineResistanceFluid(x)
        return -resistance * x * x
    }
}

// MARK: Pump

private struct Pump {
    // Coefficients of the quadratic pump curve
    //   deltaP(x) = coefConst + coefLin * x + coefQuad * x * x
    private let coefConst: Scalar = 54162.06
    private let coefLin: Scalar = -51815.13
    private let coefQuad: Scalar = 10052.31

    /// Pressure change produced by the pump at flow `x`.
    ///
    /// The terms are accumulated left to right, constant first, then the
    /// linear term, then the quadratic term.
    @differentiable(reverse)
    //@_silgen_name("Pump_deltaP")
    func deltaP(_ x: Scalar) -> Scalar {
        let linearTerm = coefLin * x
        let quadraticTerm = coefQuad * x * x
        return coefConst + linearTerm + quadraticTerm
    }
}

// MARK: RadiantSlab

private struct RadiantSlab {
    init() {}

    /// Flow-resistance coefficient of the slab's piping at flow `x`.
    ///
    /// The value is built from a chain of scalar stages:
    ///   f: reynolds  = x * diameter / (viscosity * crossSection)
    ///   g:             reynolds + 0.01  (keeps the next division finite at x == 0)
    ///   h: friction  = 64.0 / g
    ///   i: hydraulic = friction * loopLength / diameter
    ///   j: result    = hydraulic / (2.0 * 995.0 * crossSection * crossSection)
    /// so, by the chain rule,
    ///   (j(i(h(g(f(x))))))' = j'(i(h(g(f(x))))) * i'(h(g(f(x)))) * h'(g(f(x))) * g'(f(x)) * f'(x)
    ///
    /// NOTE(review): the literal constants (8.9e-4, 995.0, 163.98, 0.01905)
    /// look like SI fluid/pipe properties — units are not stated here; confirm.
    //@_silgen_name("RadiantSlab_inlineResistanceFluid")
    @differentiable(reverse)
    func inlineResistanceFluid(_ x: Scalar) -> Scalar {
        // Fixed geometry and fluid properties (independent of `x`).
        let viscosity: Scalar = 8.9e-4
        let loopCount: Scalar = 3
        let totalLength: Scalar = 163.98
        let diameter: Scalar = 0.01905
        let loopLength = totalLength / loopCount
        let crossSection: Scalar = Scalar.pi * 0.25 * diameter * diameter

        // Flow-dependent stages f ... j described above.
        let reynolds: Scalar = x * diameter / (viscosity * crossSection)
        let friction: Scalar = 64.0 / (reynolds + 0.01)
        let hydraulic: Scalar = friction * loopLength / diameter
        let denominator: Scalar = 2.0 * 995.0 * crossSection * crossSection
        return hydraulic / denominator
    }

    /// Pressure change across the slab at flow `x`,
    /// computed as `-inlineResistanceFluid(x) * x * x`.
    //@_silgen_name("RadiantSlab_deltaP")
    @differentiable(reverse)
    func deltaP(_ x: Scalar) -> Scalar {
        let resistance = inlineResistanceFluid(x)
        return -resistance * x * x
    }
}

// MARK: SubSystem

private struct Subsystem { // Simple loop like Basic Load Matching
    private let boiler = Boiler()
    private let pump = Pump()
    private let load = RadiantSlab()

    init() {}

    //@_silgen_name("Subsystem_totalDeltaP")
    /// Total pressure change around the loop at `flow`: the boiler, pump and
    /// load contributions summed in that order.
    @differentiable(reverse)
    func totalDeltaP(_ flow: Scalar) -> Scalar {
        let boilerDeltaP = boiler.deltaP(flow)
        let pumpDeltaP = pump.deltaP(flow)
        let loadDeltaP = load.deltaP(flow)
        return boilerDeltaP + pumpDeltaP + loadDeltaP
    }
}

/// Solves `subsystem.totalDeltaP(flow) == 0` for `flow` using Newton steps.
///
/// Each iteration evaluates the residual `f` and its derivative `Df` with
/// reverse-mode autodiff (`valueWithGradient`) and applies `flow += -f / Df`.
/// The loop ends after `maxIters` (20) iterations. It also ends as soon as
/// either of these falls at or below its 1e-6 tolerance: the step just
/// applied, or the residual measured before that step. A NaN step or
/// residual also ends the loop, because the `>` comparisons are then false.
///
/// This function is the reproduction for the issue. Calling
/// `valueWithGradient` directly inside the loop (as written) is the slow
/// variant. Switching to the outlined `valueWithGradientWrapper` call is
/// reported to optimize down to plain arithmetic.
///
/// - Parameters:
///   - initialFlow: Starting guess for the flow.
///   - subsystem: The system whose `totalDeltaP` is driven to zero.
/// - Returns: The flow after the last applied step.
///
/// NOTE(review): there is no guard against `Df == 0`. A zero derivative
/// produces a non-finite step, which is added to `flow` before the exit
/// check runs.
// `@inline(never)` presumably keeps this function as a standalone unit so
// its optimized SIL can be inspected in isolation.
@inline(never)
private func hardyCrossSimpleLoop(initialFlow: Scalar, subsystem: Subsystem) -> Scalar {
    let maxIters = 20
    // Exit threshold for |f|, the residual before the step.
    let balanceTol: Scalar = 1e-6
    // Exit threshold for |step|, the update just applied.
    let stepTol: Scalar = 1e-6
    // Unused; presumably left over from a finite-difference variant.
//    let delta: Scalar = 1e-4

    var flow = initialFlow

    // Same computation as the in-loop call, but outlined and never inlined.
    // Using it from the loop (the commented-out line below) is the variant
    // that, per the issue, avoids the closure-lifetime extension.
    @inline(never)
    func valueWithGradientWrapper(at: Scalar) -> (Scalar, Scalar) {
        return valueWithGradient(at: at, of: subsystem.totalDeltaP)
    }

    for _ in 0 ..< maxIters {
        // Residual `f` and derivative `Df` at the current flow, via reverse-mode autodiff.
        let (f, Df) = valueWithGradient(at: flow, of: subsystem.totalDeltaP)
        //let (f, Df) = valueWithGradientWrapper(at: flow)

        // Newton update.
        let step = -f / Df
        flow += step
        // Continue only while both the step and the pre-step residual exceed their tolerances.
        guard abs(step) > stepTol, abs(f) > balanceTol else { break }
    }
    return flow
}

private let subsystem = Subsystem()
private let nominalFlow: Scalar = 0.25 // Pre-solve guess at flow
private var flow: Scalar = 0

// Run the solver once from the nominal guess, accumulate, and report the result.
flow = flow + hardyCrossSimpleLoop(initialFlow: nominalFlow, subsystem: subsystem)
print(flow)

Expected behavior

The code should be optimized down to a simple set of arithmetic operations in both cases.

Environment

Swift version 6.1-dev (LLVM fcc20a24e57c484, Swift f802b67fc06447f) Target: arm64-apple-macosx13.0

Additional information

No response

asl commented 4 hours ago

Tagging @eeckstein @nate-chandler @atrick

asl commented 4 hours ago

@JaapWijnen