swiftlang / swift

The Swift Programming Language
https://swift.org
Apache License 2.0

[Autodiff] Unable to differentiate access to certain properties marked as `Differentiable`. #69513

Open fibrechannelscsi opened 10 months ago

fibrechannelscsi commented 10 months ago

Description

The compiler reports that access to certain properties is not differentiable in some cases. In this particular case, the type of the underlying property is a protocol that explicitly inherits from `Differentiable`.

Steps to reproduce

Paste the following code into a new project (or a main.swift file) and compile in Debug mode. Either a clean build in Xcode or running `swift main.swift` from the command line will work.

```swift
import Foundation; import _Differentiation
public struct E: Differentiable {var w: S}
@differentiable(reverse) func r(b: Y) -> E {
    let w = b.w
    let y = withoutDerivative(at: w.u).v(); print(y)
    return E(w: w)
}
public struct Y: Differentiable & H {@differentiable(reverse) public var w: S}
public struct S {public var u: any I}
extension S: Differentiable {public typealias TangentVector = G; public struct G: Differentiable {}; public mutating func move(by offset: G) {}}
extension S.G: AdditiveArithmetic {}
public protocol I : Differentiable & C {}
public protocol C {@noDerivative func v() -> [String]}
public protocol H: Differentiable {var z: () -> TangentVector { get }}
public extension H {var z: () -> TangentVector {{ Self.TangentVector.zero }}}
```

The error message provided is:

```
main.swift:3:2: error: function is not differentiable
@differentiable(reverse) func r(b: Y) -> E {
~^~~~~~~~~~~~~~~~~~~~~~~
main.swift:3:31: note: when differentiating this function definition
@differentiable(reverse) func r(b: Y) -> E {
                              ^
main.swift:5:37: note: cannot differentiate access to property 'S.u' because property type 'any I' does not conform to 'Differentiable'
    let y = withoutDerivative(at: w.u).v(); print(y)
```

Expected behavior

The compilation should succeed, or the compiler should give a more detailed error message explaining why it cannot. Note that `I` is declared as inheriting from `Differentiable` on line 12.

Environment

Additional context

Modifying line 5 of the above reproducer to:

```swift
let y = withoutDerivative(at: w).u.v(); print(y)
```

causes the build to succeed. Here, `withoutDerivative(at:)` is applied to `w` instead of `w.u`.

Moving the Differentiable conformance of S to the struct itself, that is, changing lines 9 and 10 to:

```swift
public struct S: Differentiable {public var u: any I}
extension S {public typealias TangentVector = G; public struct G: Differentiable {}; public mutating func move(by offset: G) {}}
```

does not appear to change the behaviors described above.

jkshtj commented 10 months ago

@fibrechannelscsi I've further reduced the reproducer to:

```swift
import _Differentiation

@differentiable(reverse)
func r(b: S) -> S {
    let _ = withoutDerivative(at: b.u)
    return b
}

public struct S {
    public var u: any Differentiable
}

extension S: Differentiable {
    public typealias TangentVector = Float
    public mutating func move(by offset: Float) {}
}
```

After this reduction, the issue becomes more apparent:

```
main.swift:3:2: error: function is not differentiable
@differentiable(reverse)
~^~~~~~~~~~~~~~~~~~~~~~~
main.swift:4:6: note: when differentiating this function definition
func r(b: S) -> S {
     ^
main.swift:5:37: note: cannot differentiate access to property 'S.u' because property type 'any Differentiable' does not conform to 'Differentiable'
    let _ = withoutDerivative(at: b.u)
                                    ^
```

Note the last "note" in particular:

because property type 'any Differentiable' does not conform to 'Differentiable'

I also confirmed in the Swift book (The Swift Programming Language) that this is the expected behavior:

Another problem with this approach is that the shape transformations don’t nest. The result of flipping a triangle is a value of type Shape, and the protoFlip(_:) function takes an argument of some type that conforms to the Shape protocol. However, a value of a boxed protocol type doesn’t conform to that protocol; the value returned by protoFlip(_:) doesn’t conform to Shape. This means code like protoFlip(protoFlip(smallTriangle)) that applies multiple transformations is invalid because the flipped shape isn’t a valid argument to protoFlip(_:).
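To make the rule concrete outside of autodiff, here is a minimal sketch with a hypothetical `Shape` protocol. `Triangle`, `Flipped`, and `protoFlip` are simplified stand-ins modeled on the book's example, not its exact code. The commented-out line is the kind of use the compiler rejects, because the boxed type `any Shape` does not itself conform to `Shape`; the same rule is why `any Differentiable` (and `any I`) do not conform to `Differentiable` in the notes above.

```swift
// Hypothetical protocol standing in for the book's Shape example.
protocol Shape {
    func draw() -> String
}

struct Triangle: Shape {
    func draw() -> String { "triangle" }
}

// A generic wrapper that requires its type parameter to conform to Shape.
struct Flipped<T: Shape>: Shape {
    var shape: T
    func draw() -> String { "flipped " + shape.draw() }
}

// Returns a boxed (existential) value, as protoFlip does in the book.
func protoFlip<T: Shape>(_ shape: T) -> any Shape {
    Flipped(shape: shape)
}

let boxed: any Shape = protoFlip(Triangle())
print(boxed.draw())

// Rejected by the compiler, because 'any Shape' does not conform to 'Shape':
// let nested = Flipped<any Shape>(shape: boxed)
```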

jkshtj commented 10 months ago

That said, I have some questions about the behavior of `withoutDerivative`.

As you pointed out, modifying

```swift
let _ = withoutDerivative(at: b.u)
```

to

```swift
let _ = withoutDerivative(at: b).u
```

fixes things and the code compiles.
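The working variant of the reduced reproducer can be sketched as a complete program. This assumes a toolchain that ships the `_Differentiation` module; the `valueWithPullback(at:of:)` call at the end is added only for illustration, to exercise the derivative that `@differentiable(reverse)` asks the compiler to generate.

```swift
import _Differentiation

public struct S {
    public var u: any Differentiable
}

extension S: Differentiable {
    public typealias TangentVector = Float
    public mutating func move(by offset: Float) {}
}

@differentiable(reverse)
func r(b: S) -> S {
    // Detach all of `b` first, then read `u` from the detached copy.
    // Per this thread, this form compiles, unlike withoutDerivative(at: b.u).
    let _ = withoutDerivative(at: b).u
    return b
}

// Exercise the generated derivative (added for illustration only).
let input = S(u: Float(1))
let (value, pullback) = valueWithPullback(at: input, of: r)
print(value.u, pullback(2))
```

Since `r` returns `b` unchanged, its pullback should pass the incoming tangent straight through; the `.u` read presumably contributes nothing because it comes from a value that `withoutDerivative(at:)` has already detached.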

The original error for this code was:

```
main.swift:5:37: note: cannot differentiate access to property 'S.u' because property type 'any Differentiable' does not conform to 'Differentiable'
    let _ = withoutDerivative(at: b.u)
```

This makes sense: at the SIL level, we are trying to differentiate the extraction of `u` from `b` in `b.u`, and we know that the type of `u` does not conform to `Differentiable`.
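One way to check this reading is to repeat the same pattern with a property whose type does conform to `Differentiable`. In the sketch below, `P` and `f` are hypothetical and the same `_Differentiation` toolchain assumption applies. Because `P.y` is a `Float`, differentiating the access `p.y` is allowed, and `withoutDerivative(at:)` then treats that value as a constant, so the gradient with respect to `y` is zero.

```swift
import _Differentiation

// Hypothetical type whose stored properties are both Float, so the
// synthesized conformance gives P.TangentVector members `x` and `y`.
struct P: Differentiable {
    var x: Float
    var y: Float
}

@differentiable(reverse)
func f(p: P) -> Float {
    // `p.y` is a Float, which conforms to Differentiable, so this access
    // can be differentiated; withoutDerivative(at:) then treats it as a constant.
    p.x * withoutDerivative(at: p.y)
}

let g = gradient(at: P(x: 2, y: 5), of: f)
print(g.x, g.y) // 5.0 0.0
```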

Maybe we need to write the code differently in this case? Could you explain the semantics of `withoutDerivative`? I've always struggled a bit to understand it.

fibrechannelscsi commented 10 months ago

I've dug through the documentation a little and found these:
https://www.tensorflow.org/swift/tutorials/custom_differentiation
https://swiftinit.org/docs/swift/_differentiation.withoutderivative(at:)
Broadly, it stops derivatives from propagating.
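A small numeric check of "stops derivatives from propagating", assuming a toolchain that ships `_Differentiation` and using its `gradient(at:of:)` function: wrapping one factor of `x * x` in `withoutDerivative(at:)` makes that factor a constant, so the gradient at `x = 3` drops from `2x = 6` to `3`.

```swift
import _Differentiation

// d/dx (x * x) = 2x, which is 6 at x = 3.
let full = gradient(at: Float(3)) { x in x * x }

// The second factor is treated as a constant c = 3, so d/dx (x * c) = c = 3.
let detached = gradient(at: Float(3)) { x in x * withoutDerivative(at: x) }

print(full, detached) // 6.0 3.0
```

The forward value is unchanged in both cases (`9`); only the derivative differs, which matches the identity-function behavior described in the linked documentation.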