swiftlang / swift

The Swift Programming Language
https://swift.org
Apache License 2.0

[SR-9395] [AD] Use JVP and VJP, removing primal and adjoint from syntax #51861

Closed rxwei closed 5 years ago

rxwei commented 5 years ago
Previous ID SR-9395
Radar None
Original Reporter @rxwei
Type Bug
Status Resolved
Resolution Done
Additional Detail from JIRA

| Field | Value |
|------------------|-----------------|
| Votes | 0 |
| Component/s | Swift for TensorFlow |
| Labels | Bug |
| Assignee | @rxwei |
| Priority | Medium |

Issue Description:

JVPs (Jacobian-vector products) and VJPs (vector-Jacobian products) are the canonical forms of derivative functions. Eventually, the syntax for making a function differentiable will look something like this:

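```swift
// Proposed syntax (not valid Swift today), forward mode: the body computes
// the original result, and the `differential` statement defines the
// derivative with respect to `x`.
func foo(x: T) -> T {
  Builtin.nonDiffableImplOfFoo()

  differential(v) wrt x {
    return ...
  }
}

// Proposed syntax, reverse mode: the same function with a `pullback` statement.
func foo(x: T) -> T {
  Builtin.nonDiffableImplOfFoo()

  pullback(v) wrt x {
    return ...
  }
}
```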

The `differential` and `pullback` statements get lowered to a SIL function that takes a vector and a boxed checkpoints record containing any checkpoints needed for computing derivatives. The JVP (or VJP) partially applies this function to the checkpoints record and returns the resulting differential (or pullback). Compiler-emitted "adjoint functions" (as we call them today) will follow the same convention, so that partially applying one yields a pullback.
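
A minimal Swift sketch of this convention for a function computing `x * x` on `Double`. All names (`FooCheckpoints`, `fooAdjoint`, `fooDifferentialBody`, `fooVJP`, `fooJVP`) are hypothetical, and a Swift closure capturing the checkpoints record stands in for SIL partial application of the lowered function. It only illustrates the shape of the idea, not what the compiler actually emits.

```swift
// Hypothetical checkpoints record: values saved by the original
// computation that the derivative needs later.
struct FooCheckpoints {
  let x: Double
}

// Hypothetical lowered derivative bodies for foo(x) = x * x. Each takes the
// incoming vector first and the checkpoints record last.
func fooAdjoint(_ v: Double, _ checkpoints: FooCheckpoints) -> Double {
  // Pullback: d(x * x)/dx = 2x, scaled by the incoming cotangent v.
  return v * 2 * checkpoints.x
}

func fooDifferentialBody(_ dx: Double, _ checkpoints: FooCheckpoints) -> Double {
  // Differential: 2x, scaled by the incoming tangent dx.
  return 2 * checkpoints.x * dx
}

// VJP: original result plus a pullback, obtained by binding the adjoint's
// checkpoints argument (the closure plays the role of partial application).
func fooVJP(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
  let checkpoints = FooCheckpoints(x: x)
  return (x * x, { v in fooAdjoint(v, checkpoints) })
}

// JVP: original result plus a differential, built the same way.
func fooJVP(_ x: Double) -> (value: Double, differential: (Double) -> Double) {
  let checkpoints = FooCheckpoints(x: x)
  return (x * x, { dx in fooDifferentialBody(dx, checkpoints) })
}

let (y, pullback) = fooVJP(3)
print(y, pullback(1), fooJVP(3).differential(0.5))  // 9.0 6.0 3.0
```

Putting the checkpoints parameter last matters because SIL's `partial_apply` binds a function's trailing arguments, which is why step 1 below reverses the adjoint's parameter order.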

For now, we need a future-proof implementation that allows a smooth transition from outlining-based primitive registration (derivatives written as separate functions and referenced from the `@differentiable` attribute) to the inline primitive registration shown above.

Without introducing the `differential` and `pullback` statement syntax yet, we are working toward this goal in the following steps:

  1. Reverse the parameter order of the compiler-emitted adjoint function so that it becomes a pullback when partially applied. This is required for synthesizing JVP and VJP functions.
  2. Allow specifying VJP and JVP functions in the `@differentiable` attribute syntax.
  3. Teach the differentiation pass to synthesize VJP functions by partially applying the adjoint to the checkpoints. At this stage, for compatibility with the old implementation, the checkpoints struct does not include the original arguments or original results, so the adjoint function still takes them as separate arguments.
  4. Teach the differentiation pass to use VJP functions. In primal synthesis, replace each active `apply` with an `apply` of its VJP function. Make `classifyPrimalValues` classify active `apply` instructions as values to be checkpointed. Remove the notion of "nested primal values" and its associated checkpointing logic from the implementation.
  5. Remove `primal:` and `adjoint:` from the `@differentiable` attribute syntax.
  6. Move the original parameters and original results from the adjoint function's parameter list into the checkpoints record type.
  7. Remove the `gradient` instruction and everything associated with it, e.g. the code that synthesizes gradient functions and `SILGradientOptions`.
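
To illustrate how steps 1, 3, and 4 fit together, here is a hand-written Swift sketch for `f(x) = g(h(x))` with `h(x) = x * x` and `g(u) = 3 * u`. Every name in it (`hVJP`, `gVJP`, `FCheckpoints`, `fAdjoint`, `fVJP`) is hypothetical; it only mimics at the source level what the differentiation pass would synthesize in SIL. Each active call goes through its callee's VJP, the returned pullbacks form the checkpoints record, and the adjoint (vector first, checkpoints last) is partially applied via a closure to produce `f`'s pullback.

```swift
// Hypothetical VJP of h(x) = x * x: dh/dx = 2x, and the pullback captures x.
func hVJP(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
  return (x * x, { v in v * 2 * x })
}

// Hypothetical VJP of g(u) = 3 * u: dg/du = 3, nothing to capture.
func gVJP(_ u: Double) -> (value: Double, pullback: (Double) -> Double) {
  return (3 * u, { v in v * 3 })
}

// Checkpoints record for f: the pullbacks returned by the VJPs of its
// active calls (step 4 checkpoints active applies).
struct FCheckpoints {
  let hPullback: (Double) -> Double
  let gPullback: (Double) -> Double
}

// Adjoint of f with the incoming vector first and the checkpoints last
// (step 1), so it can be partially applied to the checkpoints.
func fAdjoint(_ v: Double, _ checkpoints: FCheckpoints) -> Double {
  // Chain rule in reverse order: pull back through g, then through h.
  return checkpoints.hPullback(checkpoints.gPullback(v))
}

// Synthesized VJP (step 3): run each active call via its VJP, record the
// pullbacks, and bind the adjoint's checkpoints argument.
func fVJP(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
  let (u, hPB) = hVJP(x)
  let (y, gPB) = gVJP(u)
  let checkpoints = FCheckpoints(hPullback: hPB, gPullback: gPB)
  return (y, { v in fAdjoint(v, checkpoints) })
}

let (y, pullback) = fVJP(2)
print(y, pullback(1))  // 12.0 12.0
```

Since `f(x) = 3x^2`, the pullback at `x = 2` applied to `1` gives `f'(2) = 12`. Because each callee's pullback already captures what it needs, the checkpoints record here never holds the original arguments directly; step 6 concerns the separate case where the adjoint itself needs them.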

@rxwei and @marcrasi are driving this.

rxwei commented 5 years ago

All done!