Hi @porterchild,

The error `Referenced declaration 'max' could not be resolved` is emitted during `@derivative(of: max)` type-checking and may happen due to one of a few different reasons. Could you please share a full reproducer?
Sure,

```swift
import _Differentiation

@derivative(of: max) // error: Referenced declaration 'max' could not be resolved
func maxDerivative(_ x: Double, _ y: Double) -> (value: Double, pullback: (Double) -> (Double, Double)) {
    func pullback(_ vector: Double) -> (Double, Double) {
        if x < y {
            return (0.0, vector)
        } else {
            return (vector, 0.0)
        }
    }
    return (value: max(x, y), pullback: pullback)
}

print(gradient(of: max)(1, 2))
```
Also, building on the command line, I got a little more information in a compiler note:

```
Swift.max:1:24: note: candidate global function does not have expected type '(Double, Double) -> Double'
@inlinable public func max<T>(_ x: T, _ y: T) -> T where T : Comparable
```
That makes sense. Is there still a way to do it even though `Comparable` isn't `Differentiable`?

If there isn't, it's easy enough to work around by writing wrappers around stdlib functions and defining derivatives for the wrappers. I've used that workaround for long enough that I figured I'd ask whether there was a cleaner way.
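For reference, the wrapper workaround mentioned above might look roughly like the minimal sketch below. The names `differentiableMax` and `differentiableMaxVJP` are hypothetical. The sketch assumes a toolchain that ships the `_Differentiation` module. The wrapper has the concrete type `(Double, Double) -> Double`, so `@derivative(of:)` can resolve it without any generic requirements.

```swift
import _Differentiation

// Hypothetical wrapper: it just forwards to the standard library's generic max.
func differentiableMax(_ x: Double, _ y: Double) -> Double {
    return max(x, y)
}

// The derivative is registered on the concrete wrapper, not on the generic stdlib max,
// so its type matches the original function exactly.
@derivative(of: differentiableMax)
func differentiableMaxVJP(_ x: Double, _ y: Double) -> (
    value: Double, pullback: (Double) -> (Double, Double)
) {
    func pullback(_ v: Double) -> (Double, Double) {
        // Route the incoming gradient to whichever argument max selected.
        return x < y ? (0.0, v) : (v, 0.0)
    }
    return (value: differentiableMax(x, y), pullback: pullback)
}

print(gradient(of: differentiableMax)(1, 2)) // (0.0, 1.0)
```

When `x == y`, `x < y` is false, so this pullback sends the whole gradient to `x`. `max` is not differentiable at that point, so that tie-breaking choice is only a convention.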
Oops, I missed the fact that your original post is a full reproducer. Sorry about that.
The `@derivative(of:)` attribute requires derivative functions of generic functions to also be generic. However, derivative functions may also have additional differentiability generic requirements, explained here.
Here's how to define a working derivative for `max(_:_:)` in the standard library. `maxVJP` is generic just like `max`, but has an additional `T: Differentiable` constraint:
```swift
import _Differentiation

@usableFromInline
@derivative(of: max)
func maxVJP<T: Comparable & Differentiable>(_ x: T, _ y: T) -> (
    value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)
) {
    func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
        if x < y {
            return (.zero, v)
        } else {
            return (v, .zero)
        }
    }
    return (value: max(x, y), pullback: pullback)
}

print(gradient(of: max)(1, 2)) // (0.0, 1.0)
print(gradient(of: max)(2, 1)) // (1.0, 0.0)
```
Cool, exactly what I was hoping for, thanks!
I'm trying to define a derivative for `max(Double, Double) -> Double`, which is part of the standard library. My derivative results in the compiler error `Referenced declaration 'max' could not be resolved` on the line `@derivative(of: max)`. Am I missing something? The function `max` is definitely defined in the current scope. Thanks!