swiftlang / swift

The Swift Programming Language
https://swift.org
Apache License 2.0

[SR-13166] Default derivative implementations for protocol requirements #54231

Open dan-zheng opened 4 years ago

dan-zheng commented 4 years ago
Previous ID SR-13166
Radar rdar://problem/69987698
Original Reporter @dan-zheng
Type New Feature
Additional Detail from JIRA

| Field | Value |
|------------------|-----------------|
| Votes | 1 |
| Component/s | Compiler |
| Labels | New Feature, AutoDiff |
| Assignee | @rxwei |
| Priority | Medium |

Sub-Tasks:

Issue Description:

Overview

Default derivative implementations enable protocol requirements (such as requirements from AdditiveArithmetic, FloatingPoint, ElementaryFunctions, etc.) to be differentiable by default.

See "default derivatives and transposes" in the Differentiable Programming Manifesto for more information.

Example:

// In the standard library:
// public protocol AdditiveArithmetic: Equatable {
//   static func +(lhs: Self, rhs: Self) -> Self
//   ...
// }

extension AdditiveArithmetic where Self: Differentiable {
  @derivative(of: +)
  static func vjpAdd(_ lhs: Self, _ rhs: Self) ->
    (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
    return (lhs + rhs, { v in (v, v) })
  }
}
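
The same idea can be sketched in plain Swift, without the `_Differentiation` module, to show how a single protocol extension avoids per-type duplication. This is a minimal model under stated assumptions. `vjpAdd` here is an ordinary static method, not a registered `@derivative`, so the differentiation transform never uses it. The tangent type of every conforming type is assumed to be `Self`. `Vec2` is a hypothetical user type.

```swift
// Plain-Swift model of a "default derivative" for AdditiveArithmetic's `+`.
// Assumption: each conforming type's tangent type is Self, so no
// Differentiable/TangentVector machinery is needed for this sketch.

extension AdditiveArithmetic {
  /// Hypothetical VJP for `+`: returns the sum and a pullback that sends an
  /// incoming cotangent `v` to both operands unchanged, because
  /// d(lhs + rhs)/d(lhs) == d(lhs + rhs)/d(rhs) == 1.
  static func vjpAdd(_ lhs: Self, _ rhs: Self)
    -> (value: Self, pullback: (Self) -> (Self, Self)) {
    return (lhs + rhs, { v in (v, v) })
  }
}

/// Hypothetical user-defined type: it implements only the
/// AdditiveArithmetic requirements and writes no derivative code.
struct Vec2: AdditiveArithmetic {
  var x: Double
  var y: Double

  static var zero: Vec2 { Vec2(x: 0, y: 0) }
  static func + (lhs: Vec2, rhs: Vec2) -> Vec2 {
    Vec2(x: lhs.x + rhs.x, y: lhs.y + rhs.y)
  }
  static func - (lhs: Vec2, rhs: Vec2) -> Vec2 {
    Vec2(x: lhs.x - rhs.x, y: lhs.y - rhs.y)
  }
}

// Both the standard library's Double and the custom Vec2 pick up `vjpAdd`
// from the single extension above.
let (sum, pullback) = Double.vjpAdd(2, 3)
let (dLhs, dRhs) = pullback(1)
print(sum, dLhs, dRhs)  // 5.0 1.0 1.0

let (vSum, vPullback) = Vec2.vjpAdd(Vec2(x: 1, y: 2), Vec2(x: 3, y: 4))
let (vdLhs, vdRhs) = vPullback(Vec2(x: 1, y: 1))
print(vSum, vdLhs, vdRhs)
```

Neither `Double` nor `Vec2` declares any derivative code. The one extension supplies it. That is the duplication this issue aims to remove for real `@derivative` registrations.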

This is important for idiomatic protocol-oriented programming, because it avoids a great deal of code duplication. Without this feature, every type conforming to the protocol must:


Details

There are two cases to consider:

1. Non-@differentiable protocol requirement.

protocol P {
  func foo(_ x: Float) -> Float
}
extension P {
  @derivative(of: foo)
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    // Illustrative pullback only: assumes d(foo)/dx == 1.
    return (foo(x), { $0 })
  }
}

2. @differentiable protocol requirement.

protocol P {
  @differentiable
  func foo(_ x: Float) -> Float
}
extension P {
  @derivative(of: foo)
  func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    // Illustrative pullback only: assumes d(foo)/dx == 1.
    return (foo(x), { $0 })
  }
}
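
Ordinary Swift dispatch offers a rough counterpart to the difference between a derivative that belongs to each conformance and one that does not. The sketch below is only an analogy with hypothetical names (`P`, `fooSlope`, `fooLabel`, `Identity`, `Square`). It does not show how the compiler implements `@derivative`. A member that is declared as a requirement and given a default in an extension is dispatched through the conformance, so a conforming type can replace the default. A member that exists only in an extension is statically dispatched in generic code. Case 2 plausibly resembles `fooSlope`. Implementations of a @differentiable requirement must themselves be differentiable, which ties the derivative to each conformance. A default derivative would then presumably act like a default implementation that a type may still override.

```swift
// Ordinary-Swift analogy with hypothetical names; no AutoDiff involved.
// `fooSlope` is a requirement, so each conformance has a slot for it,
// and the extension fills that slot with a default.
// `fooLabel` exists only in the extension, so it is not part of the conformance.

protocol P {
  func foo(_ x: Double) -> Double
  func fooSlope(at x: Double) -> Double  // requirement with a default below
}

extension P {
  // Default implementation: used by any conformer that does not provide its own.
  func fooSlope(at x: Double) -> Double { 1 }
  // Extension-only member: statically dispatched in generic code.
  func fooLabel() -> String { "default" }
}

/// Uses the default implementation of `fooSlope`.
struct Identity: P {
  func foo(_ x: Double) -> Double { x }
}

/// Replaces the default `fooSlope` with its own implementation.
struct Square: P {
  func foo(_ x: Double) -> Double { x * x }
  func fooSlope(at x: Double) -> Double { 2 * x }
  // Shadows the extension member, but generic code over P will not call this.
  func fooLabel() -> String { "square" }
}

// Generic code reaches `fooSlope` through the conformance.
func slope<T: P>(_ value: T, at x: Double) -> Double {
  value.fooSlope(at: x)
}

// Generic code reaches `fooLabel` statically, via the extension.
func label<T: P>(_ value: T) -> String {
  value.fooLabel()
}

print(slope(Identity(), at: 3))  // 1.0
print(slope(Square(), at: 3))    // 6.0
print(label(Square()))           // default
print(Square().fooLabel())       // square
```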

Supporting this may require lifting the current restriction that "all implementations of @differentiable protocol requirements must themselves be marked as @differentiable":

protocol AdditiveArithmetic: Equatable {
  static func +(lhs: Self, rhs: Self) -> Self
}

extension AdditiveArithmetic where Self: Differentiable {
  @derivative(of: +)
  static func vjpAdd(_ lhs: Self, _ rhs: Self) ->
    (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
    return (lhs + rhs, { v in (v, v) })
  }
}

struct Foo: AdditiveArithmetic {
  static func +(lhs: Self, rhs: Self) -> Self {
    lhs
  }
}
Compiling this produces:
deriv.swift:13:8: error: type 'Foo' does not conform to protocol 'AdditiveArithmetic'
struct Foo: AdditiveArithmetic {
       ^
deriv.swift:14:15: note: candidate is missing attribute '@differentiable(wrt: (lhs, rhs) where Self : Differentiable)'
  static func +(lhs: Self, rhs: Self) -> Self {
              ^

SIL default witness table support may be needed.

rxwei commented 4 years ago

@swift-ci create