tensorflow / swift

Swift for TensorFlow
https://tensorflow.org/swift
Apache License 2.0

Can't define derivative for a function that definitely exists #527

Closed: porterchild closed this issue 4 years ago

porterchild commented 4 years ago

I'm trying to define a derivative for max(Double, Double) -> Double, which is part of the standard library. My derivative,

@derivative(of: max)
func maxDerivative(_ x: Double, _ y: Double) -> (value: Double, pullback: (Double) -> (Double, Double)) {
    func pullback(_ vector: Double) -> (Double, Double) {
        if x < y {
            return (0.0, vector)
        } else {
            return (vector, 0.0)
        }
    }
    return (value: max(x, y), pullback: pullback)
}

results in the compiler error "Referenced declaration 'max' could not be resolved" on the line @derivative(of: max). I wonder if I'm missing something? The function max is definitely defined in the current scope. Thanks!

dan-zheng commented 4 years ago

Hi @porterchild,

The error "Referenced declaration 'max' could not be resolved" is emitted while type-checking @derivative(of: max) and can have a few different causes. Could you please share a full reproducer?

porterchild commented 4 years ago

Sure,

import _Differentiation

@derivative(of: max)
func maxDerivative(_ x: Double, _ y: Double) -> (value: Double, pullback: (Double) -> (Double, Double)) {
    func pullback(_ vector: Double) -> (Double, Double) {
        if x < y {
            return (0.0, vector)
        } else {
            return (vector, 0.0)
        }
    }
    return (value: max(x, y), pullback: pullback)
}

print(gradient(of: max)(1, 2))

Also, when building on the command line, I got a little more information in a compiler note:

Swift.max:1:24: note: candidate global function does not have expected type '(Double, Double) -> Double'
@inlinable public func max<T>(_ x: T, _ y: T) -> T where T : Comparable

That makes sense. Is there still a way to do it, even though max's generic parameter T is only constrained to Comparable, not Differentiable?

porterchild commented 4 years ago

If there isn't, it's easy enough to work around by writing wrappers around stdlib functions and defining derivatives for the wrappers. I've just used that workaround long enough that I figured I'd ask whether there's a cleaner way.
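A minimal sketch of the wrapper workaround, assuming a Swift toolchain that ships the _Differentiation module. The names maxDouble and maxDoubleVJP are hypothetical, chosen only for illustration: the wrapper specializes the stdlib's generic max to Double, so an ordinary non-generic (Double, Double) derivative matches its signature.

```swift
import _Differentiation

// Hypothetical non-generic wrapper around the stdlib's generic max(_:_:).
func maxDouble(_ x: Double, _ y: Double) -> Double {
    return max(x, y)
}

// Because maxDouble is not generic, a plain (Double, Double) derivative matches it.
@derivative(of: maxDouble)
func maxDoubleVJP(_ x: Double, _ y: Double) -> (
    value: Double, pullback: (Double) -> (Double, Double)
) {
    // Route the incoming gradient to whichever argument was selected;
    // on ties (x == y) it goes to x, matching the pullback in the original post.
    func pullback(_ v: Double) -> (Double, Double) {
        return x < y ? (0.0, v) : (v, 0.0)
    }
    return (value: maxDouble(x, y), pullback: pullback)
}

let smallerFirst = gradient(at: 1.0, 2.0, of: maxDouble)  // (0.0, 1.0)
let largerFirst = gradient(at: 3.0, 2.0, of: maxDouble)   // (1.0, 0.0)
print(smallerFirst, largerFirst)
```

The trade-off is that call sites must call maxDouble rather than max, and the wrapper only covers Double.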

dan-zheng commented 4 years ago

Oops, I missed the fact that your original post is a full reproducer. Sorry about that.

The @derivative(of:) attribute requires the derivative of a generic function to also be generic. However, the derivative function may add differentiability generic requirements of its own, as explained here.

Here's how to define a working derivative for max(_:_:) in the standard library. maxVJP is generic just like max, but has an additional T: Differentiable constraint:

import _Differentiation

@usableFromInline
@derivative(of: max)
func maxVJP<T: Comparable & Differentiable>(_ x: T, _ y: T) -> (
  value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)
) {
    func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
        if x < y {
            return (.zero, v)
        } else {
            return (v, .zero)
        }
    }
    return (value: max(x, y), pullback: pullback)
}

print(gradient(of: max)(1, 2)) // (0.0, 1.0)
print(gradient(of: max)(2, 1)) // (1.0, 0.0)
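Because the derivative is registered for the generic max itself, it should also be picked up when max appears inside a larger differentiable computation, not only in gradient(of: max). The sketch below assumes the same _Differentiation toolchain and copies maxVJP from the reply above so that it compiles on its own; the ReLU-style closure and the x * x example are illustrative additions, not from the thread.

```swift
import _Differentiation

// Copied from the reply above so this sketch compiles on its own.
@usableFromInline
@derivative(of: max)
func maxVJP<T: Comparable & Differentiable>(_ x: T, _ y: T) -> (
  value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)
) {
    func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
        if x < y {
            return (.zero, v)
        } else {
            return (v, .zero)
        }
    }
    return (value: max(x, y), pullback: pullback)
}

// Illustrative ReLU-style use: the slope is 1 where x > 0 and 0 where x < 0.
let reluSlopeAtPositive = gradient(at: 3.0) { x in max(x, 0.0) }   // 1.0
let reluSlopeAtNegative = gradient(at: -1.0) { x in max(x, 0.0) }  // 0.0

// Chain rule through max: at x = 2, x * x = 4 beats 3, so d/dx = 2 * x = 4.
let chained = gradient(at: 2.0) { x in max(x * x, 3.0) }           // 4.0
print(reluSlopeAtPositive, reluSlopeAtNegative, chained)
```

On ties this pullback sends the whole gradient to the first argument, so for max(x, 0.0) at x = 0 the reported slope is 1.0.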

porterchild commented 4 years ago

Cool, exactly what I was hoping for, thanks!