swiftlang / swift

The Swift Programming Language
https://swift.org
Apache License 2.0

[SR-14228] [AutoDiff] "Curry thunk" differentiation regression #54819

Open dan-zheng opened 4 years ago

dan-zheng commented 4 years ago
Previous ID SR-14228
Radar None
Original Reporter @dan-zheng
Type Bug
Additional Detail from JIRA

| | |
|------------------|-----------------|
|Votes | 0 |
|Component/s | Compiler |
|Labels | Bug, AutoDiff |
|Assignee | None |
|Priority | Medium |

md5: 8635dba187654182d9292136d60a56fd

Issue Description:

Curry thunks were recently rewritten as implicit AST closures instead of SILGen'd thunks: https://github.com/apple/swift/pull/28698.
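For readers unfamiliar with the term, here is a minimal sketch of what a curry thunk amounts to at the source level. It is only an illustration: `Point`, `viaReference`, `curryThunk`, and `viaThunk` are hypothetical names, and the hand-written nested closures only approximate the implicit closures the compiler generates for an unapplied method reference.

```swift
struct Point {
  var x: Double

  static func id(x: Point) -> Point {
    return x
  }
}

// Unapplied reference to a static method. The compiler has to produce a
// function value here, which is the job of the curry thunk.
let viaReference: (Point) -> Point = Point.id

// Hand-written approximation (an assumption, not compiler output): an outer
// closure takes the metatype "self", and an inner closure forwards the
// remaining arguments to the method.
let curryThunk: (Point.Type) -> (Point) -> Point = { metatype in
  { x in metatype.id(x: x) }
}
let viaThunk = curryThunk(Point.self)

print(viaReference(Point(x: 2)).x, viaThunk(Point(x: 3)).x)
```

The nested shape matches the "implicit closure #2 in implicit closure #1" naming that appears in the SIL dump for the failing case.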

This caused regressions in curry thunk differentiation. The following reproducer is extracted from test/AutoDiff/downstream/generics.swift:

// TF-688: Test generic curry thunk cloning.
public struct TF_688_Struct<Scalar> {
  var x: Scalar
}
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
  @differentiable
  public static func id(x: Self) -> Self {
    return x
  }
}
@differentiable(wrt: x)
public func TF_688<Scalar: Differentiable>(
  _ x: TF_688_Struct<Scalar>,
  reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
) -> TF_688_Struct<Scalar> {
  reduction(x)
}

Before: no error.

// default argument 1 of TF_688<A>(_:reduction:)
sil non_abi [serialized] [ossa] @$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_ : $@convention(thin) <Scalar where Scalar : Differentiable> () -> @owned @differentiable @callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> {
bb0:
  %0 = metatype $@thin TF_688_Struct<Scalar>.Type // user: %2
  // function_ref curry thunk of static TF_688_Struct<A>.id(x:)
  %1 = function_ref @$s4main13TF_688_StructVAAs14DifferentiableRzlE2id1xACyxGAG_tFZTc : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> // user: %2
  %2 = apply %1<Scalar>(%0) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> // user: %3
  %3 = differentiable_function [parameters 0] %2 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> // user: %4
  return %3 : $@differentiable @callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> // id: %4
} // end sil function '$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_'

After: an error about differentiating a fragile function in a serialized function.
This error was introduced in https://github.com/apple/swift/pull/28582.

$ swiftc -Xllvm -debug-only=differentiation tf-688.swift
// AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0
sil shared [serialized] @AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> {
// %0                                             // users: %3, %1
bb0(%0 : $@thin TF_688_Struct<τ_0_0>.Type):
  debug_value %0 : $@thin TF_688_Struct<τ_0_0>.Type, let, name "self", argno 1 // id: %1
  // function_ref implicit closure #2 in implicit closure #1 in default argument 1 of TF_688<A>(_:reduction:)
  %2 = function_ref @$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu_A2Fcfu0_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed TF_688_Struct<τ_0_0>, @thin TF_688_Struct<τ_0_0>.Type) -> @out TF_688_Struct<τ_0_0> // user: %3
  %3 = partial_apply [callee_guaranteed] %2<τ_0_0>(%0) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed TF_688_Struct<τ_0_0>, @thin TF_688_Struct<τ_0_0>.Type) -> @out TF_688_Struct<τ_0_0> // user: %4
  %4 = convert_function %3 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %5
  %5 = differentiable_function [parameters 0] %4 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %6
  return %5 : $@differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // id: %6
} // end sil function 'AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0'

[AD] Diagnosing non-differentiability.
[AD] For value:
  %4 = convert_function %3 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %5
[AD] With invoker:
(differentiation_invoker differentiable_function_inst=(  %5 = differentiable_function [parameters 0] %4 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %6
))
tf-688.swift:14:95: error: function is not differentiable
  reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
                                                                                ~~~~~~~~~~~~~~^~
tf-688.swift:14:95: note: differentiated functions in '@inlinable' functions must be marked '@differentiable' or have a public '@derivative'; this is not possible with a closure, make a top-level function instead
  reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
                                                                                              ^

philipturner commented 2 years ago

@slavapestov I was planning to fix this bug, whose reproducer is here. The end of its stack trace is somewhere in RQM. I was planning to read your entire research paper just to understand what was going on at the end, but that seems like an overblown amount of effort. Is it possible for you to examine the crash a little in LLDB and give me enough of an understanding that I can utilize my experience with other areas of the compiler to fix the bug?

slavapestov commented 2 years ago

Your generic signature is <τ_0_0, τ_0_1, τ_0_2, τ_0_3 where τ_0_0 == τ_0_1, τ_0_2 == τ_0_3>.

τ_0_0.TangentVector is not a valid type parameter in this signature because τ_0_0 does not conform to Differentiable (which is where TangentVector is declared).

That's your bug. The autodiff code is probably forgetting to add a requirement to the signature, which is probably coming from a call to buildGenericSignature() somewhere in the autodiff code.
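The rule behind this diagnosis can be shown in plain Swift without any autodiff machinery. The sketch below is an analogy only: `HasTangent`, `Tangent`, and `tangentType(of:)` are hypothetical stand-ins for `Differentiable`, `TangentVector`, and the code that forms the dependent member type. It does not reproduce the crash.

```swift
// Hypothetical stand-in for `Differentiable`; `Tangent` plays the role of
// `TangentVector`.
protocol HasTangent {
  associatedtype Tangent
}

extension Double: HasTangent {
  typealias Tangent = Double
}

// `T.Tangent` is only a valid type parameter because the signature states
// `T: HasTangent`. Dropping that requirement makes the type checker reject
// the body, since `Tangent` is declared in `HasTangent`.
func tangentType<T: HasTangent>(of _: T.Type) -> Any.Type {
  return T.Tangent.self
}

print(tangentType(of: Double.self))
```

In the failing case the signature is built programmatically rather than written by hand, so nothing catches the missing requirement until the type is reduced.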

philipturner commented 2 years ago

@slavapestov you're the best!

For future reference, I have narrowed down the reproducer to something smaller:

import _Differentiation

struct Box<Scalar> {
  var x: Scalar
}

extension Box: Differentiable where Scalar: Differentiable {}

struct Box2<T> {
  var x2: @differentiable(reverse) (Box<T>) -> Box<T>
}

fibrechannelscsi commented 1 year ago

The reproducer posted on 4/28 is still broken with toolchains 2023-01-02a through 2023-01-18a. We get:

Invalid type parameter in getReducedType()
Original type: τ_0_0.TangentVector
Simplified term: τ_0_0.[Differentiable:TangentVector]
Longest valid prefix: τ_0_0
Prefix type: τ_0_0

Requirement machine for <τ_0_0, τ_0_1>
Rewrite system: {
}
}
Property map: {
}
Conformance paths: {
}

and

1.  Apple Swift version 5.9-dev (LLVM 3f23b4ceaf01213, Swift 0763e4b98c74b5b)
2.  Compiling with the current language version
3.  While evaluating request ASTLoweringRequest(Lowering AST to SIL for module smallProject)
4.  While emitting property descriptor for 'x2' (at /Users/user/smallProject/main.swift:10:7)

jkshtj commented 4 months ago

For future reference, I have narrowed down the reproducer to something smaller:

import _Differentiation

struct Box<Scalar> {
  var x: Scalar
}

extension Box: Differentiable where Scalar: Differentiable {}

struct Box2<T> {
  var x2: @differentiable(reverse) (Box<T>) -> Box<T>
}

This reproducer does not cause any errors on the 05/24 toolchain. The original reproducer still does.