swiftlang / swift

The Swift Programming Language
https://swift.org
Apache License 2.0

[SR-15793] [AutoDiff] Incorrect behavior with derivatives #58070

Open philipturner opened 2 years ago

philipturner commented 2 years ago
Previous ID SR-15793
Radar None
Original Reporter @philipturner
Type Bug
Additional Detail from JIRA

| | |
|------------------|-----------------|
|Votes | 0 |
|Component/s | |
|Labels | Bug |
|Assignee | None |
|Priority | Medium |

md5: 7bdb8a5d99f3bfcb51ba371fd5018a93

Issue Description:

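Automatic differentiation gives incorrect results when differentiating a mutating function: the components of the gradient come back swizzled. In the reproduction below, the custom pullbacks return the argument values themselves rather than true derivatives, so each component of the result shows which parameter it was attributed to, and a correct gradient at (2, 3, 4) would be (2, 3, 4). The table after the code lists the derivatives each call returns: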

import _Differentiation

extension Double {
  func addingThree(_ lhs: Self, _ mhs: Self, _ rhs: Self) -> Self {
    self + lhs + mhs + rhs
  }

  @derivative(of: addingThree)
  func _vjpAddingThree(
    _ lhs: Self,
    _ mhs: Self,
    _ rhs: Self
  ) -> (value: Self, pullback: (Self) -> (Self, Self, Self, Self)) {
    return (addingThree(lhs, mhs, rhs), { v in (v, lhs, mhs, rhs) })
  }

  mutating func addThree(_ lhs: Self, _ mhs: Self, _ rhs: Self) {
    self += lhs + mhs + rhs
  }

  @derivative(of: addThree)
  mutating func _vjpAddThree(
    _ lhs: Self,
    _ mhs: Self,
    _ rhs: Self
  ) -> (value: Void, pullback: (inout Self) -> (Self, Self, Self)) {
    addThree(lhs, mhs, rhs)
    return ((), { v in (lhs, mhs, rhs) })
  }
}

@differentiable(reverse)
func altAddingThree(_ x: Double, _ y: Double, _ z: Double, _ w: Double) -> Double {
  var output = x
  output.addThree(y, z, w)
  return output
}

assert((2, 3, 4) == gradient(at: 2, 3, 4, of: { 10.addingThree($0, $1, $2) }))

// fails
assert((2, 3, 4) == gradient(at: 2, 3, 4, of: { altAddingThree(10, $0, $1, $2) }))

| input | expected | output |
|---|---|---|
| (x=10, y=2, z=3) | (dx=1, dy=2, dz=3) | (dx=1, dy=3, dz=4) |
| (y=2, x=10, z=3) | (dy=2, dx=1, dz=3) | (dy=3, dx=1, dz=4) |
| (y=2, z=3, x=10) | (dy=2, dz=3, dx=1) | (dy=3, dz=4, dx=1) |
| (y=2, z=3, w=4) | (dy=2, dz=3, dw=4) | (dy=3, dz=4, dw=2) |
| (x=10, y=2) | (dx=1, dy=2) | (dx=1, dy=3) |

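
One way to read the rows above: the three values returned by the custom pullback, (lhs, mhs, rhs) = (2, 3, 4), appear rotated left by one position before being matched to the differentiated arguments, while dx is unaffected. The plain-Swift sketch below models only that symptom. It is an assumption drawn from these five rows, not a description of what the compiler does internally, and the names `rotatedLeft` and `observedArgumentGradient` are hypothetical helpers, not part of `_Differentiation`.

```swift
// Hypothetical model of the reported symptom only.
// These helpers are illustrative and are not Swift or _Differentiation APIs.

/// Rotates an array left by one position: [a, b, c] becomes [b, c, a].
func rotatedLeft(_ values: [Double]) -> [Double] {
  guard let first = values.first else { return values }
  return Array(values.dropFirst()) + [first]
}

/// Models the reported output for the non-self arguments: the pullback's
/// (lhs, mhs, rhs) values are rotated left, then cut down to the number of
/// arguments that are differentiated with respect to.
func observedArgumentGradient(tags: [Double], wrtCount: Int) -> [Double] {
  Array(rotatedLeft(tags).prefix(wrtCount))
}

// Values the custom pullback in the report returns for (lhs, mhs, rhs).
let tags: [Double] = [2, 3, 4]

// Row (y=2, z=3, w=4): expected (2, 3, 4), reported (3, 4, 2).
assert(observedArgumentGradient(tags: tags, wrtCount: 3) == [3, 4, 2])
// Rows with x, y, z: expected (dy, dz) = (2, 3), reported (3, 4).
assert(observedArgumentGradient(tags: tags, wrtCount: 2) == [3, 4])
// Row (x=10, y=2): expected dy = 2, reported dy = 3.
assert(observedArgumentGradient(tags: tags, wrtCount: 1) == [3])
```

If this reading holds, every reported row is reproduced by a single left rotation. That is consistent with an off-by-one in how pullback results are mapped to parameters, though the table alone cannot confirm the cause.
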
philipturner commented 2 years ago

Fix submitted as https://github.com/apple/swift/pull/58437

BradLarson commented 1 year ago

This behavior is still present on top-of-tree Swift and needs further investigation, so I'm reopening the issue.

jkshtj commented 4 months ago

Issue still exists in 05/24 toolchain.