swiftlang / swift

The Swift Programming Language
https://swift.org
Apache License 2.0

[AutoDiff] Runtime EXC_BAD_ACCESS when running a pullback in certain cases. #73526

Open fibrechannelscsi opened 4 months ago

fibrechannelscsi commented 4 months ago

Description

The code below generates a runtime EXC_BAD_ACCESS when executed.

Reproduction

Copy and paste the following code into a new project and build it in Debug mode.

import Foundation; import _Differentiation
public protocol N {}
public protocol H: Differentiable {}
public protocol D: H & N where Self.TangentVector: N {}
public struct O<V> {var v: V}
extension O: Differentiable where V: Differentiable {}
public extension N {var d: [String : PartialKeyPath<Self>] {return self.e(i: self) as! [String : PartialKeyPath<Self>]}}
public extension H {var z: TangentVector {Self.TangentVector.zero}}
public struct R: D {public var a: Double = 7}
public struct B<F, G> {
    @noDerivative public internal(set) var s: [String : WritableKeyPath<F, G>]
    @noDerivative public var a: Set<WritableKeyPath<F, G>> {var a = Set<WritableKeyPath<F, G>>(); a.insert(self.s.values.first!); return a}
    public init<S: Sequence>(to: S) where S.Element == WritableKeyPath<F, G> {self.s = ["0" : \R.a] as! [String : WritableKeyPath<F, G>]}
    @differentiable(reverse where F: D, G: Differentiable) public func o(f i: F) -> O<G> {
        var y: [G] = []
        for k in self.a {y = y + [p(i, at: k)]}
        var b = [O<G>]()
        for i in 0 ..< withoutDerivative(at: y.count) {
            b.append(O<G>(v: y[i]))
           // b = b + [O<G>(v: y[i])]
        }
        if withoutDerivative(at: b.count) > 1 {return O<G>(v: b[0].v)}
        else {return O<G>(v: y[0])}
    }
}
@inlinable @differentiable(reverse where W: D, M: Differentiable) public func p<W, M>(_ o: W, at j: WritableKeyPath<W, M>) -> M {return o[keyPath: j]}
@inlinable @derivative(of: p) public func vjpP<O, M>(_ o: O, at m: WritableKeyPath<O, M>) -> (value: M, pullback: (M.TangentVector) -> O.TangentVector) where O: D, M: Differentiable
{
    var z = o.z
    let r = o.d.mapValues { $0 as? WritableKeyPath<O, M> }.compactMapValues { $0 }.filter { $0.value == m }.compactMap { $0.key }
    let u = z.d.filter {r.first! == $0.key}.values.map { $0 } as! [WritableKeyPath<O.TangentVector, M.TangentVector>]
    return (value: o[keyPath: m], pullback: { d in z[keyPath: u.first!] = d; return z})
}
extension N {func e(i: any N) -> [String : AnyKeyPath] {if g == 0 {g += 1; return ["a": \R.a]}; return ["a" : \R.TangentVector.a]}}
var g = 0;
let b = R(a: 9)
@differentiable(reverse) func s(b: R) -> Double {return B<R, Double>(to: [\R.a]).o(f: b).v * 3}
print(valueWithGradient(at: b, of: s))

If run in Xcode, the error is highlighted on line 19, which contains b.append(O<G>(v: y[i])), and is reported as an "index out of range" error. If built with swift build and run in lldb, an EXC_BAD_ACCESS is triggered by an attempt to access address 0x0. The stack trace for this is listed below.

Expected behavior

The program should print: (value: 27.0, gradient: i2.R.TangentVector(a: 3.0)) and exit with an exit code of 0. Here, i2 is the name of the current project.
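As a sanity check on the expected numbers, here is a minimal sketch that does not use _Differentiation at all. It assumes the reproduction reduces to a * 3, since B.o(f:) only ever looks up \R.a, and compares that against a central finite difference. The simplified R and s(b:), the numericalGradient helper, and the step size are stand-ins chosen for illustration, not part of the original reproduction.

```swift
// Simplified stand-in for the reproduction's `R` (no Differentiable conformance needed here).
struct R { var a: Double = 7 }

// Assumption: the keypath lookup in `B.o(f:)` returns `b.a`, so `s` reduces to `a * 3`.
func s(b: R) -> Double { return b.a * 3 }

// Hypothetical helper: central finite difference with respect to `a`.
// The step size is an arbitrary illustrative choice.
func numericalGradient(at b: R, step h: Double = 1e-6) -> Double {
    let plus = s(b: R(a: b.a + h))
    let minus = s(b: R(a: b.a - h))
    return (plus - minus) / (2 * h)
}

let b = R(a: 9)
let value = s(b: b)
let gradient = numericalGradient(at: b)
print(value)     // 27.0
print(gradient)  // approximately 3.0, up to floating-point error
```

This agrees with the value of 27.0 and the gradient component of 3.0 given above.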

Environment

M1 Mac. Every toolchain I've tried exhibits this, from the 2023-07-10a nightly through 2024-05-01a. I can also reproduce it with the 6.0 2024-04-29a snapshot.

Additional information

Stack trace when run in lldb. Note that the executable is simply called a.

* thread #1, queue = 'com.apple.main-thread', stop reason = EXC_BAD_ACCESS (code=1, address=0x0)
  * frame #0: 0x0000000000000000
    frame #1: 0x0000000100155fa4 a`reverse-mode derivative of B.o(i=<unavailable>, self=_9999.B<_9999.R, Swift.Double> @ 0x000000016fdff018) at main.swift:14:72
    frame #2: 0x000000010015a1b0 a`reverse-mode derivative of s(b=(a = 9)) at main.swift:37:82
    frame #3: 0x0000000100152790 a`$s5_99991RVS2dAC13TangentVectorVIegyd_Igydo_ACSdxq_Ri_zRi0_zRi__Ri0__r0_lySdAEIsegnr_Iegnro_TR at <compiler-generated>:0
    frame #4: 0x0000000227a92f98 libswift_Differentiation.dylib`_Differentiation.valueWithPullback<τ_0_0, τ_0_1 where τ_0_0: _Differentiation.Differentiable, τ_0_1: _Differentiation.Differentiable>(at: τ_0_0, of: @differentiable(reverse) (τ_0_0) -> τ_0_1) -> (value: τ_0_1, pullback: (τ_0_1.TangentVector) -> τ_0_0.TangentVector) + 152
    frame #5: 0x0000000227a94780 libswift_Differentiation.dylib`_Differentiation.valueWithGradient<τ_0_0, τ_0_1 where τ_0_0: _Differentiation.Differentiable, τ_0_1: Swift.FloatingPoint, τ_0_1: _Differentiation.Differentiable, τ_0_1 == τ_0_1._Differentiation.Differentiable.TangentVector>(at: τ_0_0, of: @differentiable(reverse) (τ_0_0) -> τ_0_1) -> (value: τ_0_1, gradient: τ_0_0.TangentVector) + 184
    frame #6: 0x0000000100152494 a`_9999_main at main.swift:38:7
    frame #7: 0x000000018665d0e0 dyld`start + 2360

If the append is instead performed with b = b + [O<G>(v: y[i])], which is commented out on line 20, the runtime error does not occur.
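For reference, outside of differentiation the two forms build the same array, so the difference in behavior appears to be specific to the generated derivative code. The sketch below is plain Swift and only illustrates that equivalence. The input values and the names appended and concatenated are made up for the example.

```swift
// Same shape as the reproduction's `O`, without the Differentiable conformance.
struct O<V> { var v: V }

// Illustrative input values; the reproduction only ever has one element.
let y: [Double] = [9.0, 4.0]

// Form used in the crashing code: in-place append.
var appended = [O<Double>]()
for i in 0 ..< y.count { appended.append(O<Double>(v: y[i])) }

// Form that avoids the crash: rebuild the array by concatenation.
var concatenated = [O<Double>]()
for i in 0 ..< y.count { concatenated = concatenated + [O<Double>(v: y[i])] }

// Both loops produce the same elements in the same order.
print(appended.map { $0.v } == concatenated.map { $0.v })  // true
```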

fibrechannelscsi commented 4 months ago

@jkshtj