swiftlang / swift

The Swift Programming Language
https://swift.org
Apache License 2.0

[SR-13096] Fix autodiff typing rules for class-typed function parameters #55542

Open · dan-zheng opened this issue 4 years ago

dan-zheng commented 4 years ago
Previous ID SR-13096
Radar rdar://problem/72819053
Original Reporter @dan-zheng
Type Bug
Additional Detail from JIRA

| Field | Value |
|-------|-------|
| Votes | 0 |
| Component/s | Compiler |
| Labels | Bug |
| Assignee | None |
| Priority | Medium |

md5: bece69c35790a09fc5e98c37400842e1

relates to:

Issue Description:

Fix @differentiable and @derivative type-checking for class-typed function parameters.

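Class-typed values have reference semantics: a function can mutate a class instance passed as an ordinary (non-`inout`) parameter, and the caller observes the change. Thus, class-typed parameters should be treated like `inout` parameters for the purposes of `@differentiable` and `@derivative` type-checking.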
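To make the reference-semantics argument concrete, here is a minimal sketch in plain Swift (no autodiff involved). `Box`, `Value`, `setBoxX`, and `setValueX` are hypothetical names: a write through a non-`inout` class parameter is visible to the caller, just like a write through an `inout` struct parameter.

```swift
// Hypothetical class type: mutable through any reference to it.
final class Box {
    var x: Float = 0
}

// Hypothetical value type: copied on assignment and on parameter passing.
struct Value {
    var x: Float = 0
}

// `box` is not declared `inout`, yet the caller observes the write.
func setBoxX(_ box: Box, _ newValue: Float) {
    box.x = newValue
}

// For a struct, the same observable effect requires `inout`.
func setValueX(_ value: inout Value, _ newValue: Float) {
    value.x = newValue
}

let box = Box()
setBoxX(box, 3)
print(box.x) // prints 3.0

var value = Value()
setValueX(&value, 3)
print(value.x) // prints 3.0
```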


Example:

import _Differentiation

class Class: Differentiable {
  var x: Float = 0

  // Semantically acts like `Class.x.set`.
  // Type: `(Class) -> (Float) -> Void`.
  func setterForX(_ newValue: Float) {
    self.x = newValue
  }

  // This VJP is expected to pass type-checking but does not.
  @derivative(of: setterForX, wrt: (self, newValue))
  func vjpSetterForX(_ newValue: Float) -> (
    value: (), pullback: (inout TangentVector) -> Float
  ) {
    fatalError()
  }
}

This produces an unexpected error, because the class-typed `self` parameter is not treated as a "semantic result" by `autodiff::getFunctionSemanticResultTypes`:

class.swift:13:4: error: cannot differentiate void function 'setterForX'
  @derivative(of: setterForX, wrt: (self, newValue))
   ^              ~~~~~~~~~~
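For comparison, the value-type version of the same setter is accepted, because `self` in a `mutating` method is `inout` and therefore counts as a semantic result. This sketch assumes a toolchain that ships the experimental `_Differentiation` module; `Point` is a hypothetical stand-in for `Class`, and the pullback body is one possible implementation (the tangent of the new `x` flows to `newValue`, and the old `x` receives none).

```swift
import _Differentiation

// Hypothetical value-type counterpart of `Class` from the example above.
struct Point: Differentiable {
    var x: Float = 0

    // `self` is implicitly `inout` here, so this Void-returning method
    // still has a semantic result and can be differentiated.
    mutating func setterForX(_ newValue: Float) {
        x = newValue
    }

    // `wrt` defaults to (self, newValue); the pullback has the shape the
    // issue expects for the class version.
    @derivative(of: setterForX)
    mutating func vjpSetterForX(_ newValue: Float) -> (
        value: (), pullback: (inout TangentVector) -> Float
    ) {
        setterForX(newValue)
        return ((), { v in
            // After the call, x == newValue: the incoming tangent of `x`
            // is the gradient for `newValue`, and the old `x` gets zero.
            let dNewValue = v.x
            v.x = 0
            return dNewValue
        })
    }
}
```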
rxwei commented 3 years ago

@swift-ci create

asl commented 1 year ago

@dan-zheng Why does the pullback in the test case have the `(inout TangentVector, Float) -> Void` function type? Shouldn't it be `(inout TangentVector) -> Float`, since we're differentiating with respect to both arguments?

dan-zheng commented 1 year ago
Original function: `setterForX`
VJP function: `vjpSetterForX`

Original type: (inout Self, Float) -> Void
Differentiable wrt: (0, 1)
Pullback type: (inout Self.TangentVector) -> Float

Yes, I think you're right. I'll update the issue description, thanks.
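The typing rule in the summary above can be modeled as a small function over parameter descriptions. This is a hypothetical sketch, not compiler code: `Param`, `pullbackType`, and the `classAsInout` flag are illustrative names, and the model only covers `Void`-returning functions and only looks at the parameters listed in `wrt`. Parameters that are semantic results (`inout`, plus class-typed ones under the proposed rule) become `inout` pullback parameters; the tangents of the remaining `wrt` parameters are returned.

```swift
// Hypothetical, simplified description of a differentiability parameter.
struct Param {
    let tangent: String  // tangent type, e.g. "Float" or "Self.TangentVector"
    let isInout: Bool    // `inout` parameter (including `self` in a `mutating` method)
    let isClass: Bool    // class-typed parameter (mutable through its reference)
}

// Pullback type of a `Void`-returning function differentiated with respect
// to the parameters at `wrt`, or `nil` when there are no semantic results
// (the "cannot differentiate void function" case).
func pullbackType(_ params: [Param], wrt: [Int], classAsInout: Bool) -> String? {
    var inoutTangents: [String] = []    // semantic results: updated in place
    var returnedTangents: [String] = [] // other wrt parameters: returned
    for index in wrt {
        let param = params[index]
        if param.isInout || (classAsInout && param.isClass) {
            inoutTangents.append("inout " + param.tangent)
        } else {
            returnedTangents.append(param.tangent)
        }
    }
    guard !inoutTangents.isEmpty else { return nil }
    let returnType: String
    switch returnedTangents.count {
    case 0: returnType = "Void"
    case 1: returnType = returnedTangents[0]
    default: returnType = "(" + returnedTangents.joined(separator: ", ") + ")"
    }
    return "(" + inoutTangents.joined(separator: ", ") + ") -> " + returnType
}

let newValue = Param(tangent: "Float", isInout: false, isClass: false)
let mutatingSelf = Param(tangent: "Self.TangentVector", isInout: true, isClass: false)
let classSelf = Param(tangent: "Class.TangentVector", isInout: false, isClass: true)

// Original type (inout Self, Float) -> Void, wrt (0, 1).
print(pullbackType([mutatingSelf, newValue], wrt: [0, 1], classAsInout: false) ?? "error")
// prints (inout Self.TangentVector) -> Float

// Class-typed self: current rule, then the proposed rule.
print(pullbackType([classSelf, newValue], wrt: [0, 1], classAsInout: false) ?? "error")
// prints error
print(pullbackType([classSelf, newValue], wrt: [0, 1], classAsInout: true) ?? "error")
// prints (inout Class.TangentVector) -> Float
```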

asl commented 1 year ago

Great, thanks. I have this fixed locally, so the fix will be part of a larger PR:

// reverse-mode derivative of Class.setterForX(_:)
sil hidden [thunk] [always_inline] @$s4main5ClassC10setterForXyySfFTJrSSpSr : $@convention(method) (Float, @guaranteed Class) -> @owned @callee_guaranteed (@inout Class.TangentVector) -> Float {
// %0                                             // user: %3
// %1                                             // user: %3
bb0(%0 : $Float, %1 : $Class):
  // function_ref Class.vjpSetterForX(_:)
  %2 = function_ref @$s4main5ClassC13vjpSetterForXyyt5value_SfAC13TangentVectorVzc8pullbacktSfF : $@convention(method) (Float, @guaranteed Class) -> @owned @callee_guaranteed (@inout Class.TangentVector) -> Float // user: %3
  %3 = apply %2(%0, %1) : $@convention(method) (Float, @guaranteed Class) -> @owned @callee_guaranteed (@inout Class.TangentVector) -> Float // user: %5
  // function_ref autodiff self-reordering reabstraction thunk for @escaping @callee_guaranteed (@inout Class.TangentVector) -> (@unowned Float)
  %4 = function_ref @$s4main5ClassC13TangentVectorVSfIegld_AESfIegld_TJOp : $@convention(thin) (@inout Class.TangentVector, @guaranteed @callee_guaranteed (@inout Class.TangentVector) -> Float) -> Float // user: %5
  %5 = partial_apply [callee_guaranteed] %4(%3) : $@convention(thin) (@inout Class.TangentVector, @guaranteed @callee_guaranteed (@inout Class.TangentVector) -> Float) -> Float // user: %6
  return %5 : $@callee_guaranteed (@inout Class.TangentVector) -> Float // id: %6
} // end sil function '$s4main5ClassC10setterForXyySfFTJrSSpSr'
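Read at the Swift level, the SIL thunk above calls the user-written VJP (`%2`/`%3`) and wraps the returned pullback by partially applying a self-reordering thunk (`%4`/`%5`). The sketch below approximates that flow with hypothetical stand-in types and no autodiff (`TangentVector`, `Pullback`, `reorderingThunk`, and `derivativeOfSetterForX` are illustrative names). The body of the reordering thunk is not shown in the SIL, so this sketch assumes it simply forwards, since there is only one `inout` tangent to place.

```swift
// Hypothetical stand-ins for the types in the SIL above (no autodiff needed).
struct TangentVector { var x: Float }
typealias Pullback = (inout TangentVector) -> Float

final class Class {
    var x: Float = 0

    // Stand-in for `Class.vjpSetterForX(_:)` (%2): performs the original
    // mutation and returns a pullback.
    func vjpSetterForX(_ newValue: Float) -> (value: (), pullback: Pullback) {
        x = newValue
        return ((), { v in
            let dNewValue = v.x
            v.x = 0
            return dNewValue
        })
    }
}

// Stand-in for the self-reordering reabstraction thunk (%4). Assumed to
// forward unchanged: with one semantic result there is nothing to reorder.
func reorderingThunk(_ v: inout TangentVector, _ pullback: Pullback) -> Float {
    pullback(&v)
}

// Stand-in for the derivative thunk: apply the VJP (%3), then partially
// apply the reordering thunk to its pullback (%5) and return the closure.
// Parameter order (newValue, self) mirrors the SIL method convention.
func derivativeOfSetterForX(_ newValue: Float, _ object: Class) -> Pullback {
    let pullback = object.vjpSetterForX(newValue).pullback
    return { v in reorderingThunk(&v, pullback) }
}
```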