swiftlang / swift

The Swift Programming Language
https://swift.org
Apache License 2.0
67.34k stars 10.34k forks source link

[AutoDiff] Runtime crash when evaluating a pullback more than once. #64257

Closed fibrechannelscsi closed 1 year ago

fibrechannelscsi commented 1 year ago

Description The code shown below will crash at runtime with a SIGABRT / double free.

Steps to reproduce Paste the code below into a single file, compile, and run:

import _Differentiation
/// Evaluates `run` at `params` via `valueWithPullback(at:of:)` and returns the result.
///
/// - Parameters:
///   - htm: Not read anywhere in the body.
///   - params: The point at which `run` is evaluated and differentiated.
/// - Returns: The `value` and `pullback` produced by `valueWithPullback(at:of:)`,
///   repackaged as a labeled tuple.
public func simWithPullback(htm: MinModel, params: MinParams) -> (value: Output, pullback: (Output.TangentVector) -> (MinParams.TangentVector)){
    let simulationValueAndPullback = valueWithPullback(at: params, of: run)
    return (value: simulationValueAndPullback.value, pullback: simulationValueAndPullback.pullback)
}
/// Differentiable entry point of the reproducer.
///
/// The returned `Output` is built from `MiniLoop(other: params._twoDArray).results`.
/// `arrayOfSIMDs` is only mutated inside the loop and is not used to build the result.
@differentiable(reverse)
public func run(params: MinParams) -> Output {
    // Two SIMD8 vectors, every lane initialized to `1 * 3.9`.
    var arrayOfSIMDs = Array<SIMD8<Float>>(repeating: SIMD8<Float>.one * 3.9, count: 2)
    // Writes lanes 0 and 1 of element 0 through `update`, which has a custom derivative.
    for t in 0 ... 1 {arrayOfSIMDs.update(i: 0, j: t, with: Float(t))}
    // `results` is not passed, so `MiniLoop.init` falls back to its default argument for it.
    return Output(results: MiniLoop(other: params._twoDArray).results)
}
/// Differentiable struct holding a 2-D `Float` array and an optional second 2-D array.
struct MiniLoop: Differentiable {
    var results: [[Float]]
    var twoDArray: [[Float]]?
    /// Stores both arguments unchanged; `results` defaults to `[]` and `other` to `nil`.
    @differentiable(reverse)
    init(results: [[Float]] = [], other: [[Float]]? = nil) {self.results = results; self.twoDArray = other}
}
/// Differentiable return type of `run`, wrapping a single 2-D `Float` array.
public struct Output: Differentiable {
    public var results: [[Float]]
    /// Stores `results` unchanged.
    @differentiable(reverse)
    public init(results: [[Float]]) {self.results = results}
}
/// Empty differentiable struct; in this snippet it only appears as the type of the `htm` argument.
public struct MinModel: Differentiable {public init(){}}
/// Parameter struct differentiated by `run`; its only stored property is the optional `_twoDArray`.
public struct MinParams: Differentiable {
    public var _twoDArray: [[Float]]?
    /// Assigns nothing explicitly, so `_twoDArray` keeps the implicit `nil` default of an optional `var`.
    public init(){}
}
/// Refinement of `Differentiable` that exposes a `TangentVector` value through `dView`.
public protocol Differentiable2: Differentiable {var dView: TangentVector { get }}
/// Default implementation: `dView` is `TangentVector.zero`.
public extension Differentiable2 {var dView: TangentVector {return TangentVector.zero}}
// Both conformances have empty bodies, so they pick up the default `dView` implementation.
extension Float: Differentiable2 {}
extension Array: Differentiable2 where Element: Differentiable2 {}
public extension Array where Element == SIMD8<Float> {
    /// Sets lane `j` of the element at index `i` to `newValue`.
    ///
    /// - Parameters:
    ///   - i: Index of the SIMD element within the array.
    ///   - j: Lane index within that SIMD element.
    ///   - newValue: Scalar written to `self[i][j]`.
    @inlinable
    @differentiable(reverse)
    mutating func update(i: Int, j: Int, with newValue: Element.Scalar) {self[i][j] = newValue}
}
public extension Array where Element == SIMD8<Float> {
    /// Custom reverse-mode derivative registered for `update` via `@derivative(of:)`.
    ///
    /// Performs the same mutation as `update`, then returns a pullback that:
    /// 1. replaces `v.base` with `iCount` zero tangents when `v.base.count < iCount`,
    /// 2. reads `v[i][j]`,
    /// 3. sets `v.base[i][j]` to zero, and
    /// 4. returns the value read in step 2 as the tangent for `newValue`.
    ///
    /// - Returns: `()` as the value, and the pullback described above.
    @inlinable
    @derivative(of: update)
    mutating func vjpUpdate(i: Int, j: Int, with newValue: Element.Scalar) -> (value: Void, pullback: (inout TangentVector) -> (Element.Scalar.TangentVector)) {
        self.update(i: i, j: j, with: newValue)
        // Taken after the mutation and captured by the pullback closure below.
        let iCount = self.count
        // The closure also captures `i` and `j`; `v` is the inout tangent of the array.
        return ((), { v in if v.base.count < iCount {v.base = [Element.TangentVector](repeating: .zero, count: iCount)}; let dElement = v[i][j]; v.base[i][j] = .zero; return dElement})
    }
}
// Driver: obtains the value and pullback once, then invokes the same pullback twice.
let valueAndPullback = simWithPullback(htm: MinModel(), params: MinParams())
let output = valueAndPullback.value
// Same shape as `output.results`, with every element replaced by 1.
let resultOnes = output.results.map { $0.map { _ in Float(1) }}
// NOTE(review): `dView` is only defined by the default implementation in this file,
// which returns `TangentVector.zero` — so the seed is presumably zero, not ones; confirm.
var grad = valueAndPullback.pullback(Output.TangentVector(results: resultOnes.dView))
// Second invocation of the same pullback; the report says removing this line avoids the crash.
grad = valueAndPullback.pullback(Output.TangentVector(results: resultOnes.dView))
print(grad)

Expected behavior The program should print at least TangentVector(_twoDArray: []), then exit with an exit code of 0.

Environment

Additional context Running the pullback just once (i.e., removing the line immediately above print(grad)) will cause the program to exit without error. Further, omitting the line with arrayOfSIMDs.update will also cause the program to exit without error. It's likely there is memory corruption occurring in and around update(); I have seen iCount be set to some large, nonsensical integer on occasion.

BradLarson commented 1 year ago

cc @asl

asl commented 1 year ago

So, the crash itself is here:

sil shared [transparent] [thunk] @$s2pb8MiniLoopV13TangentVectorVSa16_DifferentiationAF14DifferentiableRzlE0G4ViewVyAIySf_G_GSqA2fGRzlEADVySaySaySfGG_GIegnor_TJSpSSUpSrUSUP : $@convention(thin) (@in_guaranteed MiniLoop.TangentVector, @guaranteed @callee_guaranteed (@in_guaranteed MiniLoop.TangentVector) -> (@owned Array<Array<Float>.DifferentiableView>.DifferentiableView, @out Optional<Array<Array<Float>>>.TangentVector)) -> @out Optional<Array<Array<Float>>>.TangentVector {
// %0                                             // user: %3
// %1                                             // user: %3
// %2                                             // user: %3
bb0(%0 : $*Optional<Array<Array<Float>>>.TangentVector, %1 : $*MiniLoop.TangentVector, %2 : $@callee_guaranteed (@in_guaranteed MiniLoop.TangentVector) -> (@owned Array<Array<Float>.DifferentiableView>.DifferentiableView, @out Optional<Array<Array<Float>>>.TangentVector)):
  %3 = apply %2(%0, %1) : $@callee_guaranteed (@in_guaranteed MiniLoop.TangentVector) -> (@owned Array<Array<Float>.DifferentiableView>.DifferentiableView, @out Optional<Array<Array<Float>>>.TangentVector) // user: %4
  release_value %3 : $Array<Array<Float>.DifferentiableView>.DifferentiableView // id: %4
  %5 = tuple ()                                   // user: %6
  return %5 : $()                                 // id: %6
} // end sil function '$s2pb8MiniLoopV13TangentVectorVSa16_DifferentiationAF14DifferentiableRzlE0G4ViewVyAIySf_G_GSqA2fGRzlEADVySaySaySfGG_GIegnor_TJSpSSUpSrUSUP'

The address of %2 that gets applied is wrong the second time. The thunk itself is partially applied in the autodiff subset parameters thunk for the reverse-mode derivative of MiniLoop.init(results:other:). And everything is recursively partially applied 2 or 3 times at this point :)

asl commented 1 year ago

And everything reproduces if run is reduced down to:

// Reduced reproducer: the loop body is empty and no SIMD array or `update` call is involved.
@differentiable(reverse)
public func run(params: MinParams) -> Output {
    // Empty loop, kept deliberately in this reduction.
    for t in 0 ... 1 {
    }
    // `results` is not passed to `MiniLoop.init`; only `other` is supplied.
    let res = MiniLoop(other: params._twoDArray).results
    return Output(results: res)
}

The loop is essential here.

asl commented 1 year ago

So, the address of %2 is 0x08d0eef4ea1dadab. This is the particular pattern that is used to fill freed memory, so we are using an already-freed heap object.

asl commented 1 year ago

Here is the relevant code in question:

  %41 = builtin "autoDiffProjectTopLevelSubcontext"(%2 : $Builtin.NativeObject) : $Builtin.RawPointer // user: %42
  %42 = pointer_to_address %41 : $Builtin.RawPointer to [strict] $*(predecessor: _AD__$s2pb3run6paramsAA6OutputVAA9MinParamsV_tF_bb3__Pred__src_0_wrt_0, @callee_guaranteed (@in_guaranteed MiniLoop.TangentVector) -> @out Optional<Array<Array<Float>>>.TangentVector, @callee_guaranteed (@guaranteed Output.TangentVector) -> @owned Array<Array<Float>.DifferentiableView>.DifferentiableView) // user: %43
  %43 = load %42 : $*(predecessor: _AD__$s2pb3run6paramsAA6OutputVAA9MinParamsV_tF_bb3__Pred__src_0_wrt_0, @callee_guaranteed (@in_guaranteed MiniLoop.TangentVector) -> @out Optional<Array<Array<Float>>>.TangentVector, @callee_guaranteed (@guaranteed Output.TangentVector) -> @owned Array<Array<Float>.DifferentiableView>.DifferentiableView) // users: %46, %45, %44
  %44 = tuple_extract %43 : $(predecessor: _AD__$s2pb3run6paramsAA6OutputVAA9MinParamsV_tF_bb3__Pred__src_0_wrt_0, @callee_guaranteed (@in_guaranteed MiniLoop.TangentVector) -> @out Optional<Array<Array<Float>>>.TangentVector, @callee_guaranteed (@guaranteed Output.TangentVector) -> @owned Array<Array<Float>.DifferentiableView>.DifferentiableView), 0 // user: %89
  %45 = tuple_extract %43 : $(predecessor: _AD__$s2pb3run6paramsAA6OutputVAA9MinParamsV_tF_bb3__Pred__src_0_wrt_0, @callee_guaranteed (@in_guaranteed MiniLoop.TangentVector) -> @out Optional<Array<Array<Float>>>.TangentVector, @callee_guaranteed (@guaranteed Output.TangentVector) -> @owned Array<Array<Float>.DifferentiableView>.DifferentiableView), 1 // users: %70, %69
  %46 = tuple_extract %43 : $(predecessor: _AD__$s2pb3run6paramsAA6OutputVAA9MinParamsV_tF_bb3__Pred__src_0_wrt_0, @callee_guaranteed (@in_guaranteed MiniLoop.TangentVector) -> @out Optional<Array<Array<Float>>>.TangentVector, @callee_guaranteed (@guaranteed Output.TangentVector) -> @owned Array<Array<Float>.DifferentiableView>.DifferentiableView), 2 // users: %48, %47
  %47 = apply %46(%1) : $@callee_guaranteed (@guaranteed Output.TangentVector) -> @owned Array<Array<Float>.DifferentiableView>.DifferentiableView // users: %88, %53, %51, %50
  strong_release %46 : $@callee_guaranteed (@guaranteed Output.TangentVector) -> @owned Array<Array<Float>.DifferentiableView>.DifferentiableView // id: %48

Essentially, %48 releases %46 and it gets freed, so we have a use-after-free on the second call.

fibrechannelscsi commented 1 year ago

Thanks for taking a look!