JuliaDiff / AbstractDifferentiation.jl

An abstract interface for automatic differentiation.
https://juliadiff.org/AbstractDifferentiation.jl/
MIT License

Use multiple arguments instead of a tuple for pushforward and pullback functions? #53

Open devmotion opened 2 years ago

devmotion commented 2 years ago

It seems annoying that the pushforward and pullback functions accept tuples of co-tangents instead of multiple arguments. Is there a compelling reason for doing so, or was this a design decision that could be changed? In my opinion, the main annoyance is that tuples of length 1 have to be handled as a special case (as, e.g., in https://github.com/JuliaDiff/AbstractDifferentiation.jl/pull/51). It also makes it impossible to work with actual single-argument functions that take a tuple as their only argument, but maybe this is not needed anyway. Arguably, it is also cleaner to provide multiple arguments as, well, multiple arguments instead of as a tuple.
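To make the two conventions concrete, here is a minimal sketch in plain Python (used only so the example is self-contained; all function names are hypothetical and this is not AbstractDifferentiation.jl's API). It hand-writes pushforwards for f(x, y) = x * y and g(x) = x^2 under each convention, showing that the tuple convention forces a 1-tuple wrapper whenever there is a single input.

```python
# Hypothetical illustration of the two calling conventions; not AbstractDifferentiation.jl's API.


def pushforward_tuple(x, y):
    """Tuple convention for f(x, y) = x * y: one tuple holds all input tangents."""
    def pf(tangents):
        dx, dy = tangents
        return y * dx + x * dy
    return pf


def pushforward_varargs(x, y):
    """Varargs convention for f(x, y) = x * y: one positional argument per input tangent."""
    def pf(dx, dy):
        return y * dx + x * dy
    return pf


def pushforward_tuple_square(x):
    """Tuple convention for g(x) = x**2: a lone tangent still has to be wrapped in a 1-tuple."""
    def pf(tangents):
        (dx,) = tangents  # special-case unpacking of the length-1 tuple
        return 2 * x * dx
    return pf


def pushforward_varargs_square(x):
    """Varargs convention for g(x) = x**2: the tangent is passed as is."""
    def pf(dx):
        return 2 * x * dx
    return pf


# Calls look like this:
#   pushforward_tuple(2.0, 3.0)((1.0, 0.0))   # tuple convention
#   pushforward_varargs(2.0, 3.0)(1.0, 0.0)   # varargs convention
#   pushforward_tuple_square(2.0)((1.0,))     # 1-tuple wrapper required
#   pushforward_varargs_square(2.0)(1.0)      # no wrapper needed
```

With multiple inputs the two conventions differ only in call syntax; the single-input case is where the tuple convention adds a wrapper that callers and backends must handle.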

mohamed82008 commented 2 years ago

Yes I think this can be considered along with #35.

sethaxen commented 2 years ago

> It seems annoying that the pushforward and pullback functions accept tuples of co-tangents instead of multiple arguments. Is there a compelling reason for doing so, or was this a design decision that could be changed?

While Julia functions may take multiple inputs, no Julia function returns multiple outputs; instead, a function might return a tuple of outputs. The (co)tangent of a tuple is itself tuple-like. FWIW, this is consistent with how ChainRules behaves: in ChainRules, such a (co)tangent would be represented as a Tangent, and here it is represented as a tuple. One could argue that, since AD.jl supports only functions whose inputs and outputs are arrays, a tuple returned by such a function can only be interpreted as multiple outputs, but that would be inconsistent with at least ChainRules and Zygote.
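A small Python sketch of the point that the cotangent of a tuple-valued output is itself tuple-shaped. The function `sincos` and its hand-written pullback are hypothetical; in ChainRules the cotangent would be a `Tangent` rather than a plain tuple.

```python
import math


def sincos(x):
    """One input; the single return value happens to be a tuple."""
    return (math.sin(x), math.cos(x))


def sincos_pullback(x):
    """Hand-written pullback of sincos at x (hypothetical rule for illustration).

    The cotangent dy mirrors the structure of the tuple output, so the
    pullback receives it as one tuple argument and returns one cotangent
    for the single input x.
    """
    def pb(dy):
        dsin, dcos = dy
        # d/dx sin(x) = cos(x), d/dx cos(x) = -sin(x)
        return dsin * math.cos(x) - dcos * math.sin(x)
    return pb
```

From the caller's side, `sincos` looks exactly like a function with two outputs, which is why treating a returned tuple as "multiple outputs" is a convention choice rather than something the language decides.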

> It also makes it impossible to work with actual single-argument functions that take a tuple as their only argument, but maybe this is not needed anyway.

I don't think functions with tuple inputs would be supported anyway.

devmotion commented 2 years ago

> While Julia functions may take multiple inputs, no Julia function returns multiple outputs; instead, a function might return a tuple of outputs. The (co)tangent of a tuple is itself tuple-like. FWIW, this is consistent with how ChainRules behaves: in ChainRules, such a (co)tangent would be represented as a Tangent, and here it is represented as a tuple. One could argue that, since AD.jl supports only functions whose inputs and outputs are arrays, a tuple returned by such a function can only be interpreted as multiple outputs, but that would be inconsistent with at least ChainRules and Zygote.

Sure, multiple outputs are in fact just a tuple of outputs. But that does not necessarily mean we have to use a tuple as the input to the pullback and pushforward functions.

The current design is also not completely consistent with ChainRules: in ChainRules, one never has to deal with tuples of co-tangents of length 1, because the pullback of a function with a single output just takes a single co-tangent without wrapping it in a tuple. Dropping support for tuples of length 1 would already remove the special case in #51, even if we stick with tuples for multiple outputs.
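A minimal Python sketch of the special case being described, under the assumption that an interface always passes cotangents as a tuple. `square_pullback` follows the ChainRules habit of taking a bare cotangent for a single output; `to_tuple_convention` is a hypothetical adapter, not code from #51.

```python
def square_pullback(x):
    """Pullback for g(x) = x**2: a single output takes a bare cotangent, no 1-tuple."""
    return lambda dy: 2 * x * dy


def to_tuple_convention(pullback, n_outputs):
    """Hypothetical adapter to an interface that always passes a tuple of cotangents.

    Only the n_outputs == 1 branch does extra work: the 1-tuple has to be
    unwrapped before the single-output pullback can be called. For several
    outputs the pullback is assumed to already accept a tuple.
    """
    if n_outputs == 1:
        return lambda cotangents: pullback(cotangents[0])
    return pullback
```

If the interface accepted a bare cotangent for single outputs, the `n_outputs == 1` branch would disappear, while multiple outputs could still be passed as a tuple.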

sethaxen commented 2 years ago

I think in general AD.jl has a funny relationship with inputs and outputs. For example, gradient for a single input returns a tuple, and hessian supports only single inputs yet still returns a tuple. IMO this should be changed.

The pushforward of a function (meaning the actual pushforward, not the fusion of the pushforward and the primal that frule encodes) should be structured the same way as the primal in terms of inputs and outputs. The pullback is the adjoint of the pushforward and vice versa, so a useful consistency check is whether the rules we choose are symmetric.
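The symmetry can be written out for a function with inputs $x_1, \dots, x_k$ and outputs $y_1, \dots, y_m$, using $J_{ji} = \partial y_j / \partial x_i$ (notation introduced here for illustration). The pushforward maps $k$ tangents to $m$ tangents, the pullback maps $m$ cotangents to $k$ cotangents, and the two are adjoint:

$$
\dot y_j = \sum_{i=1}^{k} J_{ji}\,\dot x_i,
\qquad
\bar x_i = \sum_{j=1}^{m} J_{ji}^{\top}\,\bar y_j,
\qquad
\sum_{j=1}^{m} \langle \bar y_j, \dot y_j \rangle = \sum_{i=1}^{k} \langle \bar x_i, \dot x_i \rangle .
$$

Read this way, whatever convention packages the $k$ arguments of the pushforward should also package the $k$ results of the pullback, and likewise for the $m$ outputs.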

I.e., these rules would maintain this symmetry, and perhaps they make sense:

This is almost consistent with ChainRules, the key differences being that 1) in ChainRules, the function itself is treated as an argument, so there are no single-argument functions (or at least, I don't know of any rule defined for a 0-argument function), hence all pullbacks return tuples; and 2) a function might actually return a tuple directly, so it's not safe to interpret a tuple return value as multiple outputs.
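Point 1) can be illustrated with a Python analogue loosely modeled on a ChainRules `rrule` for `sin`. Here `None` stands in for ChainRules' `NoTangent()`, and `rrule_sin` is a hypothetical name; the sketch only mirrors the shape of the result, not ChainRules itself.

```python
import math

NO_TANGENT = None  # stand-in for ChainRules' NoTangent()


def rrule_sin(x):
    """Return the primal output of sin(x) together with its pullback.

    Because the function itself counts as an argument, the pullback returns
    one entry per argument including the function, so even a one-argument
    function yields a 2-tuple.
    """
    y = math.sin(x)

    def sin_pullback(dy):
        # (tangent w.r.t. the function itself, cotangent w.r.t. x)
        return (NO_TANGENT, dy * math.cos(x))

    return y, sin_pullback
```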

gdalle commented 1 year ago

I could try to give this a shot once #93 is merged.

gdalle commented 10 months ago

I'm starting to work on this and wondering what to do with the lazy derivatives. Should we only allow them for a single input/output? Applying matrix multiplication to a tuple is a bit counterintuitive anyway.
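One possible reading of "only allow them for a single input/output", as a minimal Python sketch. `LazyJacobian` here is a hypothetical class loosely named after AD.jl's lazy derivatives and does not reproduce that implementation; it evaluates `J @ v` on demand through a pushforward and rejects tuples of tangents.

```python
class LazyJacobian:
    """Hypothetical lazy Jacobian for a single-input, single-output function.

    No matrix is built: J @ v is evaluated on demand via the pushforward.
    """

    def __init__(self, pushforward):
        self.pushforward = pushforward

    def __matmul__(self, v):
        if isinstance(v, tuple):
            # Restricting to one input avoids defining "matrix times tuple".
            raise TypeError("LazyJacobian supports a single input only")
        return self.pushforward(v)


def pushforward_at(x):
    """Pushforward of f(x) = [x[0] * x[1], x[0] + x[1]] at the point x."""
    return lambda v: [x[1] * v[0] + x[0] * v[1], v[0] + v[1]]
```

For example, `LazyJacobian(pushforward_at([2.0, 3.0])) @ [1.0, 0.0]` gives the first Jacobian column, `[3.0, 1.0]`.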