Closed porterchild closed 4 years ago
As is, it looks like `gradient(of:)` doesn't take `inout` functions either.
Thanks for the questions! I'll answer your second question first, since it provides context.
It is intentional that there is a minimal set of differential operators, defined for "functionally-typed" functions:
@differentiable (A) -> R
@differentiable (A, B) -> R
@differentiable (A, B, C) -> R
Defining a minimal set of differential operators keeps APIs simple for users - you don't need to remember what it means to apply `gradient(of:)` to a function with `inout` parameters.
It also reduces API surface area. Otherwise, we'd have to overload differential operators for a huge combination of cases.
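To make this concrete, here is a minimal sketch (assuming the S4TF-era `gradient(of:)` operator from the `_Differentiation` module) of how the operator applies to one- and two-parameter functionally-typed functions:

```swift
import _Differentiation

// One-parameter function: gradient(of:) returns a (Float) -> Float.
func f(_ x: Float) -> Float { x * x }
let df = gradient(of: f)
// df(3) computes 2 * 3 = 6

// Two-parameter function: the returned closure yields a tuple of
// partial derivatives, one per differentiable parameter.
func g(_ x: Float, _ y: Float) -> Float { x * y }
let dg = gradient(of: g)
// dg(2, 5) computes (∂g/∂x, ∂g/∂y) = (5, 2)
```

The same pattern extends to the three-parameter case, so these few overloads cover the common functional shapes without per-`inout`-combination variants.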
For functions with `inout` or `@noDerivative` parameters, you can still use differential operators by forming a functionally-typed closure:
// Case (1): function with `@noDerivative` parameters.
// Example: `pow(_ x: Float, _ n: Int) -> Float`.
// `gradient(of: pow)` doesn't work.
// Solution: form a closure capturing the `@noDerivative` parameters.
// This makes sense and is pretty usable!
_ = gradient(of: { (x: Float) in pow(x, 3) })
// Case (2): function with `inout` parameters.
// Your example:
func square(_ a: inout Float) {
    a *= a
}
// `gradient(of: square)` doesn't work.
// Solution: form a functionally-typed closure creating a temporary variable.
// This is more heavyweight.
func squared(_ a: Float) -> Float {
    var tmp = a
    square(&tmp)
    return tmp
}
_ = gradient(of: squared)
// A nasty one-liner:
_ = gradient(of: { (a: Float) -> Float in var tmp = a; square(&tmp); return tmp })
In practice, I think there are fewer uses for "directly applying differential operators to functions with `inout` parameters" than "applying differential operators to functionally-typed functions that internally call mutating functions".
I think the latter results in code that is easier to understand, too.
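For example, here is a sketch of that latter pattern (the helper `doubleInPlace` and wrapper `scaleBy8` are hypothetical names for illustration): a functionally-typed wrapper mutates a local copy in a loop, and differential operators apply to the wrapper directly.

```swift
import _Differentiation

// An `inout` helper, called repeatedly inside a loop.
func doubleInPlace(_ a: inout Float) {
    a *= 2
}

// A functionally-typed wrapper that mutates a local copy.
// Differentiating through local mutation and loops is supported.
func scaleBy8(_ x: Float) -> Float {
    var result = x
    for _ in 0..<3 {
        doubleInPlace(&result)
    }
    return result  // result == 8 * x
}

let dScale = gradient(of: scaleBy8)
// dScale(x) is 8 for every x, since scaleBy8 is linear.
```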
Regarding an `inout` version of the `makeRecomputedInGradient` function: for your use case, is it acceptable to apply `makeRecomputedInGradient` to a functionally-typed wrapper closure? It's the same technique as above:
// From: https://www.tensorflow.org/swift/tutorials/custom_differentiation#recomputing_activations_during_backpropagation_to_save_memory_checkpointing
func makeRecomputedInGradient<T: Differentiable, U: Differentiable>(
    _ original: @escaping @differentiable (T) -> U
) -> @differentiable (T) -> U {
    return differentiableFunction { x in
        (value: original(x), pullback: { v in pullback(at: x, in: original)(v) })
    }
}
// Your example:
func square(_ a: inout Float) {
    a *= a
}
// Solution: form a functionally-typed closure creating a temporary variable.
func squared(_ a: Float) -> Float {
    var tmp = a
    square(&tmp)
    return tmp
}
let squaredRecomputing = makeRecomputedInGradient(squared)
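A usage sketch (assuming the S4TF-era `valueWithGradient(at:in:)` operator): the checkpointed function behaves like any other differentiable function, recomputing `squared` in the backward pass instead of storing intermediate values.

```swift
// Sketch: differentiate the checkpointed function like any other.
let (value, grad) = valueWithGradient(at: 4, in: squaredRecomputing)
// `value` is squared(4); `grad` is d/da squared(a) evaluated at a = 4.
```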
If there are reasons why this doesn't work for your use case (performance?), please share more details!
I think there may be better solutions than overloading differentiation APIs to accept `@differentiable` functions with `inout` parameters, which is not really ideal.
Actually this works perfectly! My initial design had the function call inside a loop, so I went with inout for performance. Later my design changed to enclose the loop, and I just didn't think to get rid of the inout approach. I realized while reading your reply that I can just change back to a functional version instead of inout. Apologies, it's a bit of a trivial thing to realize given the effort of your answer. I was trying so hard to fix the problem that I got tunnel vision :)
Thanks!
In practice, I think there are fewer uses for "directly applying differential operators to functions with inout parameters" than "applying differential operators to functionally-typed functions that internally call mutating functions".
I think the latter results in code that is easier to understand, too.
Agree
I'm glad you found a solution 🙂
Thanks again for the question! We can turn the answers here into documentation sometime.
I'm struggling to make an inout version of the makeRecomputedInGradient function (I have an inout function with only one argument, no outputs (except the inout of course)).
I started with the signature
I'm stuck because I don't know how to do the equivalent of calling `pullback(at:in:)` (since it doesn't appear to take inout functions). Along similar lines, shouldn't I be able to do something like this:
As is, it looks like `gradient(of:)` doesn't take inout functions either.
Thanks, it's all slowly making sense, but I'm not quite all the way!