Closed — porterchild closed this issue 4 years ago
Some hasty answers, feel free to ask for clarification!
TF-1250 tracks differentiation with respect to multiple results. `Cannot differentiate functions with both an 'inout' parameter and a result` is an artificial limitation that should be lifted within this month.
> I've tried adding an inout argument to `loss`, but I've failed to exclude it from differentiation with `@differentiable(wrt: (only one of the two))`.
This is probably because you called `loss` in a way such that both the result and the `inout` parameter are active. It doesn't matter that you registered a `@differentiable` attribute with respect to only one parameter - activity analysis tells differentiation that it needs derivatives with respect to both the `inout` parameter and the result. If you don't care about derivatives of one of the two, you need to indicate this at the caller by using `withoutDerivative(at:)`.
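A minimal sketch of what that looks like at a call site, using the S4TF-era APIs from this thread (the `loss` function, `scale` argument, and values here are made up for illustration):

```swift
import _Differentiation

// Hypothetical example: `scale` is a second argument we don't want
// derivatives for. Wrapping it in `withoutDerivative(at:)` at the call
// site tells activity analysis that no derivative is needed for it.
@differentiable
func loss(_ x: Float, _ scale: Float) -> Float {
    x * x * scale
}

let g = gradient(at: 2.0 as Float) { x in
    loss(x, withoutDerivative(at: 3.0))
}
// d/dx of 3 * x^2 at x = 2 is 12
```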
This is something we can clear up in differentiable Swift usage documentation: you cannot "exclude function arguments from differentiation" by declaring a `@differentiable` attribute with some subset of `wrt:` parameters. A `@differentiable` attribute just declares a function as being differentiable with respect to some parameters; it doesn't exclude the function from being differentiable with respect to more parameters. Something we've considered is a `@noDerivative` function parameter attribute, which would do this exclusion and eliminate the need for callers to use `withoutDerivative(at:)`.
> It's where I have a loss function that I'm feeding some stateful data structure, and I want the state of the data as well as the cost at the end. However, the output of the loss function must be a single scalar, so I can't return the state as well. I've been solving it by saving the state off to a global variable (right before `return cost`) until now, but I need a better solution.
I think you can always avoid mutation via functionalization: make `loss` return a `Differentiable`-conforming struct of `cost` and state, with no mutation. This is invasive though - TF-1250 will support differentiation with respect to an `inout` parameter and a result, which may address your use case without code changes.
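A rough sketch of that functionalization idea, with placeholder types and a made-up state update standing in for the real ones:

```swift
import _Differentiation

// Hypothetical sketch: return cost and state together in one
// Differentiable struct instead of mutating state via `inout`.
struct LossResult: Differentiable {
    var cost: Float
    var state: Float  // stand-in for the real state type
}

@differentiable
func loss(_ x: Float, state: Float) -> LossResult {
    LossResult(cost: x * x, state: state + x)  // made-up cost and update
}
```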
I ought to stop closing issues upon answering. Feel free to close if the answers make sense!
> Something we've considered is a `@noDerivative` function parameter attribute, which would do this exclusion and eliminate the need for callers to use `withoutDerivative(at:)`.
I think that would be perfect. I've tried that intuitively a few times without knowing it didn't exist yet.
> This is probably because you called `loss` in a way such that both the result and `inout` parameter are active.
Actually, the error pops up as soon as I add the inout parameter to the function signature, without actually using it in the function at all.
> I think you can always avoid mutation via functionalization: make `loss` return a `Differentiable`-conforming struct of `cost` and state, no mutation.
Doesn't a function have to return a single `FloatingPoint` in order to be able to call `grad(of:)` on it?
Either way, reading your answer I realized I can wrap my `loss` with a non-differentiable function that has the same arguments, gets the gradient by calling `loss`, and then returns whatever it wants. Thanks!
> Actually, the error pops up as soon as I add the inout parameter to the function signature, without actually using it in the function at all.
This makes sense until differentiation supports multiple results (TF-1250). As soon as you add an `inout` parameter, the function has multiple semantic results and thus cannot currently be differentiated.
> Doesn't a function have to return a single `FloatingPoint` in order to be able to call `grad(of:)` on it?
True. You can always use `pullback(at:in:)` instead, passing in an `Output.TangentVector(...)` seed. `gradient` is just a special case of `pullback` for scalar-output original functions, applying `pullback(at: ..., in: ...)(1)`.
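As a sketch of that relationship, using the S4TF-era `gradient(at:in:)` and `pullback(at:in:)` operators (the `square` function and values are made up):

```swift
import _Differentiation

// Hypothetical example: for a scalar-output function, `gradient` is
// the same as applying `pullback` to the seed 1.
@differentiable
func square(_ x: Float) -> Float { x * x }

let g1 = gradient(at: 3.0 as Float, in: square)
let g2 = pullback(at: 3.0 as Float, in: square)(1)
// g1 == g2 == 6
```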
> `gradient` is just a special case of `pullback` for scalar-output original functions, applying `pullback(at: ..., in: ...)(1)`.
Thank you! That connected a few things in my brain. I still look forward to the day when I stop being surprised by trivial implications of my elementary calculus understanding :)
> You can always use `pullback(at:in:)` instead, passing in an `Output.TangentVector(...)` seed.
So I could pass `Output.TangentVector(cost: 1, state: State(zeroes))` to `pullback` and I would get the same gradient as calling `grad(of:)` where `loss()` only returned the cost, correct?
> Thank you! That connected a few things in my brain.
Sure thing! Check out the Wikipedia page on "Gradient":

> In vector calculus, the gradient of a scalar-valued differentiable function `f` of several variables, `f: R^n -> R`, is the vector field [...] `∇f: R^n -> R^n`, whose value at a point `p` is the vector whose components are the partial derivatives of `f` at `p`: ...
The differentiable programming manifesto's section on "pullback-producing differential operators" shows how `gradient` is implemented in terms of `pullback`.
> So I could pass `Output.TangentVector(cost: 1, state: State(zeroes))` to `pullback` and I would get the same gradient as calling `grad(of:)` where `loss()` only returned the cost, correct?
Yes.
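A sketch of that seeding pattern, with a made-up `LossResult` struct standing in for the real output type:

```swift
import _Differentiation

// Hypothetical example: seeding `pullback` with cost: 1 and a zero
// state tangent recovers the gradient of the cost alone.
struct LossResult: Differentiable {
    var cost: Float
    var state: Float  // stand-in for the real state type
}

@differentiable
func loss(_ x: Float) -> LossResult {
    LossResult(cost: x * x, state: x + 1)  // made-up cost and state
}

let pb = pullback(at: 3.0 as Float, in: loss)
let g = pb(LossResult.TangentVector(cost: 1, state: 0))
// Same as the gradient of the cost `x * x` alone: 6
```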
Yeah it makes perfect sense mathematically, I guess I just hadn't connected some things yet.
When I learned calculus, it was through the lens of `f(x) = x`-looking algebraic functions, where 'taking the derivative' meant an algebraic manipulation of the original equation which, once found, you could run forward as its own separate equation for `f'(x)`. So in general, thinking about derivatives instead as a vector propagating backwards through a series of locally linear operations in the original equation has been a bit of an adjustment for me. It totally makes sense why that's the default way to think about it when implementing autodiff; it's just not the way I first learned it. So every so often I'll make a connection in the vector-propagation interpretation that already made sense to me in the high-level-algebra interpretation.
I obviously didn't understand it very deeply even in the way I first learned it, otherwise the vector-propagation-interpretation would have immediately made perfect sense.
I'm 60% sure anything I just said actually makes sense haha
> When I learned calculus, it was through the lens of `f(x) = x`-looking algebraic functions, where 'taking the derivative' meant an algebraic manipulation of the original equation which, once found, you could run forward as its own separate equation for `f'(x)`.
Differentiation as "an algebraic manipulation of the original equation" sounds like symbolic differentiation à la Wolfram Alpha. Automatic differentiation is in fact quite similar to symbolic differentiation (section 2.1 of "Demystifying Differentiable Programming").
> So in general, thinking about derivatives instead as a vector propagating backwards through a series of locally linear operations in the original equation has been a bit of an adjustment for me.
The chain rule of differentiation gives us a way to compute derivatives of composed functions. Here's a short chain:

`dz/dx = dz/dy * dy/dx`

Imagine you have a long chain with a bunch of `dy_i/dy_{i-1}` factors. We can evaluate the chain via:

- Left-to-right multiplication (starting from `dy_1/dx`). This is forward-mode differentiation.
- Right-to-left multiplication (starting from `dz/dy_n`). This is reverse-mode differentiation (backpropagation).

In fact, the association (grouping of parentheses) of multiplication affects efficiency. This is the same problem as matrix chain multiplication. The optimal association may be mixed-mode (a mix of forward-mode and reverse-mode differentiation).
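The two accumulation orders can be sketched with made-up scalar factors (for scalars the product is identical either way; the order of accumulation is what distinguishes the two modes):

```swift
// Hypothetical scalar factors dy_i/dy_{i-1} along a chain.
let factors: [Double] = [2.0, 3.0, 0.5, 4.0]

// Forward mode: accumulate left-to-right, starting from dy_1/dx.
let forward = factors.reduce(1.0, *)

// Reverse mode: accumulate right-to-left, starting from dz/dy_n.
let reverse = factors.reversed().reduce(1.0, *)

// forward == reverse == 12.0; with matrix-valued factors, the
// association chosen would change the cost of the computation.
```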
I'm not sure what introductory resources best explain these ideas. I learned quite a bit from Conal Elliott's "The simple essence of automatic differentiation", but I'm still digesting it after months.
> Differentiation as "an algebraic manipulation of the original equation" sounds like symbolic differentiation à la Wolfram Alpha.
Yes, exactly what I meant, forgot the term.
Thanks for contextualizing my hazy thoughts, I've got a lot to digest here!
I've gotten stuck at a similar junction a few times. It's where I have a loss function that I'm feeding some stateful data structure, and I want the state of the data as well as the cost at the end. However, the output of the loss function must be a single scalar, so I can't return the state as well. I've been solving it by saving the state off to a global variable (right before `return cost`) until now, but I need a better solution. I've tried adding an inout argument to `loss`, but I've failed to exclude it from differentiation with `@differentiable(wrt: (only one of the two))`. I still get the compiler error `Cannot differentiate functions with both an 'inout' parameter and a result`.
So generally, what exactly is the pattern to include and exclude arguments from differentiation, and specifically, is there a better way to solve my state problem than a global variable?
Thanks!