Open yuanqing-wang opened 3 years ago
Thanks for the question! Take a look at these docs if you haven't already: jax.custom_vjp with nondiff_argnums, and the custom derivative rules notebook in general.
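For readers who haven't used it, here is a minimal sketch of jax.custom_vjp with nondiff_argnums, modeled on the gradient-clipping pattern from the custom derivative rules notebook. The names clip_gradient, lo, and hi are illustrative only and do not come from this thread.

```python
from functools import partial

import jax
import jax.numpy as jnp


# lo and hi are plain Python values, so they are marked non-differentiable.
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
    # Identity on the forward pass; only the backward pass is customized.
    return x


def clip_gradient_fwd(lo, hi, x):
    # fwd takes the same arguments as the function and returns
    # (output, residuals). No residuals are needed here.
    return x, None


def clip_gradient_bwd(lo, hi, residuals, g):
    # bwd receives the non-differentiable arguments first, then the
    # residuals, then the output cotangent. It returns one cotangent
    # per differentiable argument (here only x).
    return (jnp.clip(g, lo, hi),)


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

# The incoming cotangent is 10.0, which is clipped to 1.0.
grad_value = jax.grad(lambda x: 10.0 * clip_gradient(-1.0, 1.0, x))(2.0)
```

Note that this applies to ordinary Python functions decorated with jax.custom_vjp, not to core.Primitive instances.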
Actually, both jax.interpreters.ad.defvjp_all and jax.interpreters.ad.defvjp are internal-only functions, and are also deprecated! So I don't recommend using those.
Does that answer your question?
Thanks @mattjj! In that case, what would be the recommended way to define the vjp for user-defined primitives? Is it jax.custom_vjp as well?
Ah, I wasn't sure if you meant that you literally had a core.Primitive instance, or if you just meant it more loosely as 'an autodiff primitive', which a function with a custom_jvp or custom_vjp rule is, I'd say.
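As a rough illustration of the looser sense of 'autodiff primitive', the sketch below attaches a JVP rule to an ordinary function with jax.custom_jvp. The function log1pexp is an illustrative choice, not something from this thread. Because the rule is linear in the tangent, JAX can use it for both forward and reverse mode.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    # log(1 + exp(x)); autodiff will not trace inside this body
    # because the JVP rule below is used instead.
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,) = primals
    (t,) = tangents
    y = log1pexp(x)
    # d/dx log(1 + exp(x)) = sigmoid(x); the output tangent is linear in t.
    return y, jax.nn.sigmoid(x) * t


# Forward mode uses the rule directly.
y, y_dot = jax.jvp(log1pexp, (0.0,), (1.0,))

# Reverse mode is derived from the same JVP rule.
g = jax.grad(log1pexp)(0.0)
```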
If you have your own core.Primitive, the best thing to do is to define a JVP rule and a separate transposition rule. Internally, JAX doesn't really have VJP rules, but instead it derives VJPs automatically by composing JVPs with partial evaluation and transposition. (Transposition rules only need to be defined for linear primitives, or more precisely for primitives that are linear with respect to some of their inputs.)
It's true that ad.defvjp and ad.defvjp_all (the former is a thin wrapper around the latter) let you define VJP rules directly for primitives, but if you look at the implementation of defvjp_all you can see it's actually setting up a JVP rule and a transpose rule under the hood. The PR message in #636 has details about how that's set up. However, this trick has limitations: it's awkward, and it breaks forward-mode autodiff.
Because of those limitations to ad.defvjp, and because user-defined Primitives are rare (and usually only added by experts), we hope to keep things simple and just advise folks who write their own Primitives to think in terms of separate JVP and transposition rules. We've yet to see a case where that doesn't work out.
What do you think about defining separate JVP and transpose rules? A VJP rule is just a composition of those two things, so if you have a VJP rule worked out, you should be able to separate it into those two steps.
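The "VJP = JVP followed by transposition" idea can be sketched with public APIs only, without defining a core.Primitive. In this illustration (the function f and the test values are made up), jax.linearize produces the linear JVP map at a point, and jax.linear_transpose turns that linear map into the corresponding VJP. The result is compared against jax.vjp.

```python
import jax
import jax.numpy as jnp


def f(x):
    # An arbitrary nonlinear function used only for illustration.
    return jnp.sin(x) * x


x = jnp.array([1.0, 2.0, 3.0])
ct = jnp.array([1.0, 0.5, -1.0])  # cotangent for the output

# Step 1: linearize f at x. f_jvp is the linear map t -> J(x) @ t.
y, f_jvp = jax.linearize(f, x)

# Step 2: transpose the linear map to get ct -> J(x)^T @ ct.
# The second argument only supplies the input shape and dtype.
f_vjp_from_jvp = jax.linear_transpose(f_jvp, x)
(x_bar,) = f_vjp_from_jvp(ct)

# Reference: the VJP computed by JAX directly.
_, f_vjp = jax.vjp(f, x)
(x_bar_ref,) = f_vjp(ct)
```

For a custom Primitive, the JVP rule plays the role of step 1 and the transposition rule for the linear part plays the role of step 2.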
Yes, I do have a core.Primitive instance.
But I'm actually not sure whether a core.Primitive is the best way to achieve what I have in mind:
vjp and jvp calculations.
Hm, because you have an external un-traceable function, that does sound like a Primitive is appropriate, though perhaps you only need it for that one particular function call. A Primitive basically means "don't trace inside here."
(I haven't dug into those links yet.)
cc @jakevdp since this is related to sparsity
Hi @mattjj, if I used a custom JVP for a core.Primitive, is there a way to define nondiff_argnums then?
Hi @mattjj, sorry, I still haven't figured out how to do this: is it possible to specify nondiff_argnums for the JVP of a core.Primitive?
Hi, is there a way one could specify nondiff_argnums for the custom vjp of a primitive? Like using jax.interpreters.ad.defvjp_all or jax.interpreters.ad.defvjp?