I think it would be great if we could improve this! But if we change names, maybe we could also change the design a little bit at the same time?
In general I really like the JAX API for jacobian products.
So for `R_op`, for instance, I think we should return a tuple (val, grads), so that we can use Ops that compute both the gradient and value at the same time. The lack of that is really awkward for ODE models etc. And for existing ones we can just return the old value.
I think the same is also true for `L_op`. It is useful to be able to return a new graph for the forward values, if you know that you will also want to compute the backward ops. (For ODEs, for instance, you have to do some setup in the forward solver when you later want to compute the backward grads.)
About the names: We could use the JAX names, which have the advantage that people might already know them? I always get hopelessly confused about those and have to stop and think each time which one is which though.
Maybe instead `push_forward` and `pull_back`?
So I guess something like:
```python
# default impl based on current API
def push_forward(self, primals, tangents):
    return (self(*primals), self.R_op(primals, tangents))
    # (or pass in outputs?)


def pull_back(self, inputs, cotangents):
    outputs = self(*inputs)
    return (outputs, self.L_op(inputs, outputs, cotangents))
```
Are `back_prop`, `fwd_prop` very wrong?
On the new API
I don't see the advantage of having a method that returns both types of output by default in PyTensor. That's only useful if you are sure you will need both in the final graph, but that's not a given (even if a user were to call a `value_and_grad`, they might not want to compile both).

Unlike JAX, we can't introspect the actual `perform` function and remove unused computations. So to achieve the same thing you have to spell it out in the PyTensor graph.
I still think the approach where you make your Op return value and grads as `make_node` outputs and selectively hide / reuse them when the gradient is requested is fine (see https://www.pymc.io/projects/examples/en/latest/case_studies/wrapping_jax_function.html#bonus-using-a-single-op-that-can-compute-its-own-gradients for an example). (PS: We should update that example to use `L_op`, so that `self(*inputs)` is not needed.)
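Roughly, that change would look like this (a hedged sketch; the class name and the two-output structure are placeholders, not the actual code from that example):

```python
from pytensor.graph.op import Op


# Placeholder for the wrapper Op in the linked example, assumed to have two outputs.
class WrappedJaxOp(Op):
    # With `grad`, the node outputs are not available, so the Op is re-applied
    # just to get symbolic handles to them:
    def grad(self, inputs, output_gradients):
        value, extra = self(*inputs)
        ...

    # With `L_op`, the already-built outputs are passed in directly:
    def L_op(self, inputs, outputs, output_gradients):
        value, extra = outputs
        ...
```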
You would probably want a rewrite to replace the node by one that only computes the value if the gradients are never used, but the same would be the case if you followed the JAX API, right?
Something like this:
```python
@register_specialize
@node_rewriter(tracks=[ValueAndGradOp])
def replace_if_outputs_not_used(fgraph, node):
    # Assume we have a node with 2 inputs, which by default returns the value and 2 grads.
    # If only the value output is used in this graph,
    # we replace the node by a version that only computes that.
    used_outputs = [bool(fgraph.clients[out]) for out in node.outputs]
    if used_outputs == [True, False, False]:
        return [value_op(*node.inputs), None, None]
    return None

...
out, *_ = value_and_grad_op(*inputs)
fn = pytensor.function(inputs, out)
```
The other option is to try and merge value and grad nodes into the "super_op", but that's harder to do.
PS: This is very similar to what the `MaxAndArgmax` Op does.
> Are `back_prop`, `fwd_prop` very wrong?
I've never seen those before... I guess I like pullback and pushforward, because those are simply the names of those functions in math. :-)
> I don't see the advantage of having a method that returns both types of output by default in PyTensor. That's only useful if you are sure you will need both in the final graph, but that's not a given (even if a user were to call a `value_and_grad`, they might not want to compile both).
I don't think that's the way to look at it. It's more like: "Given that you are computing the derivatives, we'd now have a faster way to compute the values in case you want to have both". If you don't need the new way to compute the original function output, you can just ignore it.
It is true that you could usually try to find rewrites that make it so that computing both doesn't do more work than is necessary. But those rewrites add a lot of work and make things more complicated.
Let's use `Loop` as an example.
```
do:
    go_on, state = update(state)
while go_on
```
The pushforward should be a loop again, and at first we could try to only compute the derivatives:
```
do:
    go_on, d_state = d_update(d_state)
while go_on
```
but this doesn't work, because `d_update` will not only need `d_state` as input, but also `state` (at the very least to compute the same `go_on` values as the original loop). So more like this:
```
do:
    go_on, state, d_state = d_update(state, d_state)
while go_on
```
But now we already compute the output state. So if we want to have both, it would be much easier to just use this state output instead of the original one.
It is true that we could just use the original state output and the `d_state` output of the derivative, and try to come up with rewrites that notice that we are computing `state` twice. Some loop-fusion pass might for instance do this. But why go to all that trouble if we had the correct thing all along?
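For comparison, the existing `Scan` already behaves roughly this way: its gradient is a second scan that consumes the saved forward states rather than recomputing them. A minimal example:

```python
import pytensor
import pytensor.tensor as pt

x0 = pt.dscalar("x0")
# Forward loop: repeatedly apply tanh for a fixed number of steps.
states, _ = pytensor.scan(fn=lambda s: pt.tanh(s), outputs_info=x0, n_steps=10)
loss = states[-1]
# The gradient graph contains a second scan that runs backwards over `states`,
# reusing the forward trajectory rather than recomputing it.
g = pt.grad(loss, x0)
fn = pytensor.function([x0], [loss, g])
```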
I think in general it is much easier for rewrites to notice that some computation isn't used and throw it away, than to try and figure out that two different computations are the same thing.
The fuse-elemwise rewrites are, I think, actually exactly this problem. If we wanted, we could investigate whether to change the elemwise gradient implementations so that they produce the fused elemwise Ops directly.
The loop example is very relevant because it's an Op that "hides" a graph inside it. So fusing redundant computations becomes harder.
If you now hide the value and gradient graphs inside more complicated Ops, you are doing something similar at the level of what used to be pure Ops. Avoiding repeated computations now becomes hard in general (JAX doesn't suffer from this because it traces through the inner graph).
If I have `exp(x)` used somewhere else as well as inside `ExpAndGrad(x)`, you would probably want to get rid of the first, but now you are doing this for possibly every Op?
The `ValueAndGrad` makes a lot of sense for an Op that can't be represented easily by pure Ops (or one that hides a complex inner graph), but I don't see why it makes sense for simple Ops, and hence by default.
Another option is to make gradient a dummy Op so that some rewrites can be done before we actually unroll the gradient graphs (when we already know exactly what will be needed as output). This has been brought up a few times before.
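A rough sketch of what that could look like (everything here is hypothetical, just to make the idea concrete): a placeholder Op records the requested gradient, and a late rewrite expands it into the actual gradient graph once the surrounding graph (and which outputs are used) is known.

```python
from pytensor.gradient import grad
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import node_rewriter


# Hypothetical placeholder: "gradient of cost wrt x", not yet unrolled.
class GradPlaceholder(Op):
    def make_node(self, cost, x):
        return Apply(self, [cost, x], [x.type()])

    def perform(self, node, inputs, output_storage):
        raise NotImplementedError("must be expanded by a rewrite before compilation")


@node_rewriter(tracks=[GradPlaceholder])
def expand_grad_placeholder(fgraph, node):
    # At this point we can inspect how the gradient is used in the surrounding
    # graph before deciding how to unroll it; here we just fall back to `grad`.
    cost, x = node.inputs
    return [grad(cost, wrt=x)]
```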
Where would something like ExpAndGrad ever come from?
```python
from typing import List

from pytensor.graph.op import Op


class Exp(Op):
    def push_forward(self, primals, tangents):
        (x,) = primals
        (d_x,) = tangents
        y = self(x)
        return y, [y * d_x]


# Or probably even a bit better, to give a bit more flexibility:
class Exp(Op):
    def push_forward(self, node, tangents, need_output_tangent: List[bool]):
        (x,) = node.inputs
        (y,) = node.outputs
        (d_x,) = tangents
        (need_out_tangent,) = need_output_tangent
        if need_out_tangent:
            return y, [y * d_x]
        else:
            return y, [None]
```
`need_output_tangent` contains bools indicating whether we are interested in the tangents for each output. Most Ops could probably just ignore that, but I think it can come in handy in a couple of cases.
I really don't see where that would make rewrites more difficult. For most Ops it would just continue to do the same thing, but in cases where that doesn't make sense (ODE, loop, if-else, integrate?, elemwise if we are adventurous) we make things much simpler and avoid having to play pretty tricky rewrite games to get decent performance.
Then I don't understand, what does that do? If you are asking for the gradient wrt an input, you already have the output (it's what you passed as input to the `pt.grad` function).

The `L_op` even receives the node outputs explicitly, so it already "has them".
I was assuming you meant to combine operations inside the "perform" method.
This is what it looks like now:

```python
class Exp(Op):
    def push_forward(self, primals, outputs, tangents):
        (x,) = primals
        (d_x,) = tangents
        (y,) = outputs
        return [y * d_x]
```
The problem with the current way is that when we compute the tangents of the output we sometimes have to generate a second way to compute the output (see the loop example: how would we compute the output tangents without also re-computing the output in that loop as well?). The old way to compute the output is still valid, but using that would often be wasteful (which sometimes can be fixed using tricky rewrites). So in the (very common) case that you actually want to compute both the output and the output tangents, it is faster to use this new output graph. The old API requires us to just throw that away, while the alternative API allows us to use it if we like.
So for the loop it might look something like this:
```python
class Loop(Op):
    def push_forward(self, node, tangents, need_output_tangent: List[bool]):
        push_forward_update = push_forward(self.update, ...not for the go_on arg...)  # not implemented yet...
        loop_outputs = Loop(push_forward_update)(*node.inputs, *tangents)
        return loop_outputs[:len(node.inputs)], loop_outputs[len(node.inputs):]
```
Okay so you actually mean stuff like Loops or Blackbox Ops, not vanilla Ops like Elemwise? That's what I didn't understand.
It could make sense because it is indeed easier to prune unused paths of the graph than to merge redundant paths (although we do have rewrites for that with Scan). But you can only do that for Ops that have an inner graph we can inspect. We cannot inspect the inner graph of Elemwise (other than Composite), so I wouldn't become adventurous.
And is this relevant for `L_op` as well?
The downside is that the gradient methods would become harder to write. There is some value to simplicity :)
My only issue with push_forward and pull_back is just that they are a bit verbose; still better than `L_op` and `R_op`, so we can go with those if nobody has other suggestions.
Are pushforward and pullback widely used in the auto-diff literature? I've personally not seen them before, and this paper explicitly distinguishes between them. It seems to me that names like forward_diff and reverse_diff (or *_prop) are closer to describing what is going on in pytensor (as I understand it), and are more commonly understood by anyone working with other tensor packages.
FWIW, the fact that `grad` exists alongside `L_op` and `R_op` is much more confusing than the names `L_op` and `R_op` themselves.
The nomenclature around this whole thing is certainly a bit messy at the moment.
I don't think "pushforward" and "pullback" are that common in the autodiff literature right now, but they have been names for what we are doing for a long time in differential geometry.
"forward_diff" and "reverse_diff" seem to be also overloaded terms. Does this refer to computing the whole jacobian as in the paper you linked? Or jacobian-vector-products? Or in the case of "reverse_diff" only computing gradients of functions to a scalar? If we are talking about jacobians, that also actually sounds like everything has to be a 1D vector, but this clearly isn't the case. Saying what r_op is computing without using terms from diffgeo and without assuming everything is a 1d vector is actually pretty tricky.
I think vector and matrix calculus as often taught right now is just not the right tool to understand and talk about the functions we are working with. In contrast diffgeo has exactly the tools we need (and quite a bit more that maybe we could also use I guess).
I have the impression that the correspondence between these concepts in diffgeo and autodiff is only now getting properly explored (for instance in the paper you just linked), but I think this will become more and more common as time goes on and people figure out how helpful it is for understanding what we are doing. JAX is already using a lot of that language to describe `jvp` and `vjp` ("covector", "tangent"). We could just use the same names as JAX does (jvp and vjp), but I really don't like those; even after using them for quite some time I always have to think about which is which.
The diffgeo perspective also solves the problem of understanding what the arguments to those functions are. If you look at the old Theano docs, it really is quite hard to understand even what the inputs to `L_op` or `grad` are actually supposed to be. But with the diffgeo view that becomes really easy (assuming some diffgeo knowledge, so not exactly easy in practice): the inputs of the pushforward or forward-mode autodiff are tangent vectors, the inputs to reverse-mode autodiff or the pullback are covectors.
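For concreteness, this is how that reads in JAX terms: `jax.jvp` takes tangent vectors and returns the output together with its tangent (the pushforward), while `jax.vjp` returns the output plus a function that maps output covectors back to input covectors (the pullback).

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.exp(x)
x = jnp.array([0.0, 1.0])

# Pushforward / forward mode: tangent vector in, output and output tangent out.
v = jnp.array([1.0, 0.0])
y, y_tangent = jax.jvp(f, (x,), (v,))

# Pullback / reverse mode: output covector in, input covector out.
y, f_vjp = jax.vjp(f, x)
u = jnp.array([1.0, 1.0])
(x_cotangent,) = f_vjp(u)
```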
So I guess we have a situation where some math field has the exact concepts and names for the operations we are doing, but that field isn't known to that many of the people working with them. Do we then want to use the math names, or something that might be less precise but more widely known? I guess I favor just going with the math names (after all, the library is called pytensor), but I can certainly see arguments against that.
Right now we have `grad`, `L_op`, `R_op`.

**Deprecate `grad` in favor of `L_op`:**

`grad` is exactly the same as `L_op` except it doesn't have access to the outputs of the node that is being differentiated: https://github.com/pymc-devs/pytensor/blob/24b67a860b6a3d38e9f23505800c4d2af0aee852/pytensor/graph/op.py#L366-L393

`L_op` allows one to reuse the same output when it's needed in the gradient, which means there is one less node to be merged during compilation. This is mostly relevant for nodes that are costly to merge, such as Scan (see https://github.com/pymc-devs/pytensor/commit/0f5a06d7f7bd8480d52d1d2b754bffd7abd14f21). It also saves time spent on `make_node` (e.g., inferring static type shapes). In the scalar Ops it's used everywhere to quickly check if the output types are discrete (see https://github.com/pymc-devs/pytensor/commit/fd628c5a74adbfcdc72bf7362ffab07a7b7c0cd6). There are some opportunities still missing, for example the gradient of `Exp`: https://github.com/pymc-devs/pytensor/blob/24b67a860b6a3d38e9f23505800c4d2af0aee852/pytensor/scalar/basic.py#L3096-L3107

Could instead return `(gz * outputs[0],)`.
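That is, roughly (a sketch, not the actual code in `pytensor/scalar/basic.py`):

```python
from pytensor.scalar.basic import UnaryScalarOp, exp


class ExpSketch(UnaryScalarOp):
    # Current style: recomputes exp(x) inside the gradient graph.
    def grad(self, inputs, output_gradients):
        (x,) = inputs
        (gz,) = output_gradients
        return (gz * exp(x),)

    # L_op style: reuses the output that the node already computes.
    def L_op(self, inputs, outputs, output_gradients):
        (gz,) = output_gradients
        return (gz * outputs[0],)
```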
More importantly for this issue, I think we should deprecate `grad` completely, since everything can be equally well done with `L_op`.

**Rename `L_op` and `R_op`?**

The names are pretty non-intuitive, and I don't think they are used in any other auto-diff libraries. The equivalents in JAX are `vjp` and `jvp` (you can find a direct translation in https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/). Other suggestions were discussed some time ago by Theano devs here: https://groups.google.com/g/theano-dev/c/8-z2C59rmQk/m/gm432ifVAg0J?pli=1
**Remove `R_op` in favor of double application of `L_op` (or make it a default fallback)**

There was some fanfare some time ago about `R_op` being completely redundant in a framework with dead code elimination: https://github.com/Theano/Theano/issues/6035

That thread also suggests the double `L_op` may generate more efficient graphs in some cases (because most of our rewrites target the type of graphs generated by `L_op`?). It probably makes sense to retain `R_op` for cases where we/users know that's the best approach, but perhaps default/revert to double `L_op` otherwise. Stale PRs that never quite got into Theano: https://github.com/Theano/Theano/pull/6400 and https://github.com/Theano/Theano/pull/6037
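For reference, the double-`L_op` trick from that thread looks roughly like this (assuming `pytensor.gradient.Lop` keeps the Theano-style `Lop(f, wrt, eval_points)` signature):

```python
import pytensor.tensor as pt
from pytensor.gradient import Lop

x = pt.vector("x")
v = pt.vector("v")      # tangent vector, same shape as x
y = pt.tanh(x)          # some differentiable graph

u = y.type("u")         # dummy covector with the same type as y
jt_u = Lop(y, x, u)     # J^T u, a graph that is linear in u
jv = Lop(jt_u, u, v)    # (J^T)^T v = J v, i.e. what R_op would compute
```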