fibrechannelscsi opened this issue 10 months ago
@fibrechannelscsi I've reduced the reproducer further:

```swift
import _Differentiation

@differentiable(reverse)
func r(b: S) -> S {
    let _ = withoutDerivative(at: b.u)
    return b
}

public struct S {
    public var u: any Differentiable
}

extension S: Differentiable {
    public typealias TangentVector = Float
    public mutating func move(by offset: Float) {}
}
```
After this reduction, the issue becomes more apparent:
```text
main.swift:3:2: error: function is not differentiable
@differentiable(reverse)
~^~~~~~~~~~~~~~~~~~~~~~~
main.swift:4:6: note: when differentiating this function definition
func r(b: S) -> S {
     ^
main.swift:5:37: note: cannot differentiate access to property 'S.u' because property type 'any Differentiable' does not conform to 'Differentiable'
    let _ = withoutDerivative(at: b.u)
                                    ^
```
Notice the last note in particular:

> because property type 'any Differentiable' does not conform to 'Differentiable'
I also confirmed in the Swift book (*The Swift Programming Language*) that this is the expected behavior:

> Another problem with this approach is that the shape transformations don’t nest. The result of flipping a triangle is a value of type `Shape`, and the `protoFlip(_:)` function takes an argument of some type that conforms to the `Shape` protocol. However, a value of a boxed protocol type doesn’t conform to that protocol; the value returned by `protoFlip(_:)` doesn’t conform to `Shape`. This means code like `protoFlip(protoFlip(smallTriangle))` that applies multiple transformations is invalid because the flipped shape isn’t a valid argument to `protoFlip(_:)`.
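To make that rule concrete outside of differentiation, here is a minimal sketch loosely modeled on the book's `Shape` example. The `Triangle`, `FlippedShape`, and `drawFlipped` definitions are simplified assumptions for illustration, not the book's exact code. The commented-out line is the kind of use the compiler rejects, because the existential `any Shape` does not itself conform to `Shape`; `any Differentiable` is in the same position with respect to `Differentiable`. The final call relies on implicit existential opening (Swift 5.7 and later).

```swift
// A minimal protocol standing in for the book's Shape (simplified).
protocol Shape {
    func draw() -> String
}

struct Triangle: Shape {
    var size: Int
    // Draws rows of 1...size stars.
    func draw() -> String {
        (1...size).map { String(repeating: "*", count: $0) }
            .joined(separator: "\n")
    }
}

// A generic wrapper whose parameter must conform to Shape.
struct FlippedShape<T: Shape>: Shape {
    var shape: T
    // Reverses the order of the wrapped shape's rows.
    func draw() -> String {
        shape.draw()
            .split(separator: "\n")
            .reversed()
            .joined(separator: "\n")
    }
}

// A value of the boxed protocol type.
let boxed: any Shape = Triangle(size: 3)

// Rejected: `any Shape` is not a type that conforms to `Shape`,
// so it cannot be substituted for T.
// let bad = FlippedShape<any Shape>(shape: boxed)

// Accepted: the generic function receives the concrete type opened
// from the box (implicit existential opening, Swift 5.7+).
func drawFlipped<T: Shape>(_ shape: T) -> String {
    FlippedShape(shape: shape).draw()
}

let flipped = drawFlipped(boxed)
print(flipped)
```

The generic function never returns `T` itself, so the opened type does not need to escape the call.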
That said, I have some questions about the behavior of `withoutDerivative`.
As you pointed out, changing `let _ = withoutDerivative(at: b.u)` to `let _ = withoutDerivative(at: b).u` fixes the problem, and the code compiles.
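A minimal sketch of the reduced reproducer with that one-line change applied, assuming a toolchain that ships the `_Differentiation` module (not every toolchain does). The `valueWithPullback(at:of:)` call at the end is only there to exercise the derivative: since `r` returns its argument unchanged, its pullback should pass the incoming tangent straight through.

```swift
import _Differentiation

public struct S {
    public var u: any Differentiable
}

extension S: Differentiable {
    public typealias TangentVector = Float
    public mutating func move(by offset: Float) {}
}

@differentiable(reverse)
func r(b: S) -> S {
    // Detach the whole struct first; the property is then read from a
    // value that the differentiator treats as a constant.
    let _ = withoutDerivative(at: b).u
    return b
}

// Exercise the derivative: r is the identity on S, so its pullback
// should return the incoming tangent unchanged.
let (_, pullback) = valueWithPullback(at: S(u: Float(2)), of: r)
let tangent = pullback(1)
print(tangent)
```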
The original error for this code was:

```text
main.swift:5:37: note: cannot differentiate access to property 'S.u' because property type 'any Differentiable' does not conform to 'Differentiable'
    let _ = withoutDerivative(at: b.u)
```
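This makes sense: at the SIL level, `b.u` is evaluated from the active value `b` before `withoutDerivative` ever sees it, so the compiler has to differentiate the access to `u`, and the type of `u`, `any Differentiable`, does not conform to `Differentiable`.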
Maybe we need to write the code differently in this case?
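For contrast, a minimal sketch where the stored property's type does conform to `Differentiable`, assuming a toolchain that ships `_Differentiation`; `P` and `q` are hypothetical names. Here `withoutDerivative(at: b.u)` is accepted, which suggests the diagnostic is about the property type `any Differentiable` rather than about `withoutDerivative` itself. Because `c` is treated as a constant, the gradient of `b.u * c` with respect to `b.u` at `u = 3` is `c`, i.e. 3, rather than the 6 you would get for `b.u * b.u`.

```swift
import _Differentiation

// Stored property of a concrete differentiable type; the compiler
// synthesizes P.TangentVector with a matching `u` member.
struct P: Differentiable {
    var u: Float
}

@differentiable(reverse)
func q(b: P) -> Float {
    // Accessing b.u is differentiable here because Float conforms.
    let c = withoutDerivative(at: b.u)  // treated as a constant
    return b.u * c
}

let g = gradient(at: P(u: 3), of: q)
print(g.u)
```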
Could you explain the semantics of `withoutDerivative`? I've always struggled a bit to understand them.
I've dug through the documentation a little and found these:

- https://www.tensorflow.org/swift/tutorials/custom_differentiation
- https://swiftinit.org/docs/swift/_differentiation.withoutderivative(at:)

In short, `withoutDerivative(at:)` returns its argument unchanged but treats it as a constant, so no derivative propagates back through it.
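As a way to think about those semantics, here is a toy forward-mode model in plain Swift. It is only an analogy, not how `_Differentiation` is implemented: `Dual` and `detached` are hypothetical names, and `detached` plays the role of `withoutDerivative(at:)` by keeping the value while zeroing the derivative.

```swift
// A value paired with its derivative with respect to one input.
struct Dual {
    var value: Double
    var derivative: Double

    // Product rule: (l * r)' = l' * r + l * r'.
    static func * (l: Dual, r: Dual) -> Dual {
        Dual(value: l.value * r.value,
             derivative: l.derivative * r.value + l.value * r.derivative)
    }
}

// Analogue of withoutDerivative(at:): same value, derivative dropped.
func detached(_ x: Dual) -> Dual {
    Dual(value: x.value, derivative: 0)
}

// x * x: both factors carry a derivative, so d/dx = 2x.
func square(_ x: Dual) -> Dual { x * x }

// x * detached(x): the second factor is a constant, so d/dx = x.
func squareDetached(_ x: Dual) -> Dual { x * detached(x) }

// Seed the input with derivative 1 to differentiate with respect to x.
let x = Dual(value: 3, derivative: 1)
let a = square(x)
let b = squareDetached(x)
print(a.value, a.derivative)  // 9.0 6.0
print(b.value, b.derivative)  // 9.0 3.0
```

The value is identical in both cases; only the derivative changes. With the real API, the gradient of `x * withoutDerivative(at: x)` at 3 would likewise be 3.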
Description

In certain cases, the compiler marks access to certain properties as not `Differentiable`. In this particular case, the underlying property (a protocol) is explicitly marked as conforming to `Differentiable`.

Steps to reproduce

Paste the following code into a new project (or a `main.swift` file) and compile in Debug mode. Either a clean build in Xcode or `swift main.swift` on the command line will work.

The error message provided is:
Expected behavior

The compilation should succeed, or the compiler should provide a more detailed error message explaining why compilation cannot succeed. Note that `I` is marked as differentiable on line 12.

Environment
Additional context

Modifying line 5 in the above reproducer to:

causes the build to succeed. Here, we call `withoutDerivative(at: w)` instead of `withoutDerivative(at: w.u)`.

Moving the `Differentiable` conformance of `S` to the struct itself, that is, changing lines 9 and 10 to:

does not appear to change the behavior described above.