jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`nondiff_argnums` for `custom_vjp` primitives? #4918

Open: yuanqing-wang opened this issue 3 years ago

yuanqing-wang commented 3 years ago

Hi, is there a way to specify nondiff_argnums for the custom VJP of a primitive, for example using jax.interpreters.ad.defvjp_all or jax.interpreters.ad.defvjp?

mattjj commented 3 years ago

Thanks for the question! Take a look at these docs if you haven't already: jax.custom_vjp with nondiff_argnums, and the custom derivative rules notebook in general.
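For reference, here is a minimal sketch of jax.custom_vjp with nondiff_argnums. It assumes a toy power function, and the names power, power_fwd and power_bwd are illustrative rather than taken from the docs. The thing to notice is the argument order. The backward rule receives the non-differentiable arguments first, followed by the residuals and the output cotangent. It returns cotangents only for the differentiable arguments.

```python
from functools import partial

import jax


# power / power_fwd / power_bwd are illustrative names for this sketch.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def power(x, n):
    return x ** n


def power_fwd(x, n):
    # Same signature as power; returns (primal output, residuals for bwd).
    return x ** n, x


def power_bwd(n, x, g):
    # Nondiff args come first, then the residuals, then the output cotangent.
    # Return a tuple with cotangents for the differentiable args (x) only.
    return (g * n * x ** (n - 1),)


power.defvjp(power_fwd, power_bwd)

print(jax.grad(power)(3.0, 2))
```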

Actually, both jax.interpreters.ad.defvjp_all and jax.interpreters.ad.defvjp are internal-only functions, and are also deprecated! So I don't recommend using those.

Does that answer your question?

yuanqing-wang commented 3 years ago

Thanks @mattjj! In that case, what would be the recommended way to define the VJP for user-defined primitives? Is it jax.custom_vjp as well?

mattjj commented 3 years ago

Ah, I wasn't sure whether you literally had a core.Primitive instance, or whether you meant it more loosely as 'an autodiff primitive', which I'd say a function with a custom_jvp or custom_vjp rule is.

If you have your own core.Primitive, the best thing to do is to define a JVP rule and a separate transposition rule. Internally, JAX doesn't really have VJP rules, but instead it derives VJPs automatically by composing JVPs with partial evaluation and transposition. (Transposition rules only need to be defined for linear primitives, or more precisely for primitives that are linear with respect to some of their inputs.)
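Here is a minimal sketch of that split. It assumes a recent JAX where Primitive is importable from jax.extend.core, and it falls back to jax.core on older versions. The primitives cube_p and scale_p are hypothetical. cube_p is nonlinear and gets only a JVP rule. Its tangent is produced by scale_p, which is linear in its second input, so scale_p is the only primitive that needs a transposition rule. jax.grad then obtains the VJP by linearizing and transposing, and no VJP rule is written anywhere.

```python
import jax
from jax.interpreters import ad

try:  # newer JAX exposes Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older JAX
    from jax.core import Primitive

# Hypothetical linear primitive: scale(c, t) = c * t, linear in t only.
scale_p = Primitive("scale")


def scale(c, t):
    return scale_p.bind(c, t)


scale_p.def_impl(lambda c, t: c * t)
# Assumes c and t have the same shape, so the output looks like t.
scale_p.def_abstract_eval(lambda c, t: t)


def scale_transpose(ct, c, t):
    # Only t is a linear input (an UndefinedPrimal here); c is a known value.
    assert ad.is_undefined_primal(t) and not ad.is_undefined_primal(c)
    # One cotangent per input; None means "no cotangent for c".
    return None, scale(c, ct)


ad.primitive_transposes[scale_p] = scale_transpose

# Hypothetical nonlinear primitive: cube(x) = x ** 3. JVP rule only.
cube_p = Primitive("cube")


def cube(x):
    return cube_p.bind(x)


cube_p.def_impl(lambda x: x ** 3)
cube_p.def_abstract_eval(lambda x: x)


def cube_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # The tangent output is linear in t and is expressed with scale_p.
    return cube(x), scale(3 * x ** 2, t)


ad.primitive_jvps[cube_p] = cube_jvp

# jax.grad derives the VJP as: JVP -> partial evaluation -> transpose.
print(jax.grad(cube)(2.0))
print(jax.jvp(cube, (2.0,), (1.0,)))
```

This sketch only covers eager evaluation, jax.jvp and jax.grad. Using these primitives under jax.jit would also require registering lowering rules. Taking higher-order derivatives would also require a JVP rule for scale_p.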

It's true that ad.defvjp and ad.defvjp_all (the former is a thin wrapper around the latter) let you define VJP rules directly for primitives, but if you look at the implementation of defvjp_all you can see it's actually setting up a JVP rule and a transpose rule under the hood. The PR message in #636 has details about how that's set up. However, this trick has limitations: it's awkward, and it breaks forward-mode autodiff.

Because of those limitations of ad.defvjp, and because user-defined Primitives are rare (and usually added only by experts), we hope to keep things simple and just advise folks who write their own Primitives to think in terms of separate JVP and transposition rules. We've yet to see a case where that doesn't work out.

What do you think about defining separate JVP and transpose rules? A VJP rule is just a composition of those two things, so if you have a VJP rule worked out, you should be able to separate it into those two steps.
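The composition can be seen concretely with public APIs. This is a sketch, and f is an arbitrary example function. jax.linearize performs the JVP plus partial evaluation. jax.linear_transpose then transposes the resulting linear map. The result agrees with jax.vjp.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Arbitrary example function; any differentiable function works.
    return jnp.sin(x) * x


x = jnp.array(0.7)

# JVP + partial evaluation: y = f(x) and the linear map t -> J(x) t.
y, f_jvp = jax.linearize(f, x)

# Transposing that linear map gives the VJP: ct -> J(x)^T ct.
f_vjp_from_parts = jax.linear_transpose(f_jvp, x)
ct = jnp.ones_like(y)
(ct_x_parts,) = f_vjp_from_parts(ct)

# Reference: JAX's built-in VJP.
y_ref, f_vjp = jax.vjp(f, x)
(ct_x_ref,) = f_vjp(ct)

print(ct_x_parts, ct_x_ref)
```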

yuanqing-wang commented 3 years ago

Yes I do have a core.Primitive instance.

But I'm actually not sure whether a core.Primitive is the best way to achieve what I have in mind:

mattjj commented 3 years ago

Hm, because you have an external un-traceable function, that does sound like a Primitive is appropriate, though perhaps you only need it for that one particular function call. A Primitive basically means "don't trace inside here."
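Here is a minimal sketch of that "don't trace inside here" boundary. It assumes the external code can be stood in for by a plain NumPy call, and external_power, ext_power_p and ext_power are hypothetical names. The impl receives concrete arrays, so it can call non-JAX code. Abstract evaluation only describes the output shape and dtype. The JVP rule is written with ordinary JAX ops, so no transpose rule is needed. The integer exponent n is passed to bind as a keyword parameter rather than as an operand, so it is never traced or differentiated.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.interpreters import ad

try:  # newer JAX exposes Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older JAX
    from jax.core import Primitive


def external_power(x: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical stand-in for an external, un-traceable routine."""
    return np.power(x, n)


ext_power_p = Primitive("ext_power")  # hypothetical primitive


def ext_power(x, *, n):
    # n is a primitive parameter, not an operand: it is static, never traced.
    return ext_power_p.bind(x, n=n)


def _ext_power_impl(x, *, n):
    # Eager evaluation: hand a concrete NumPy array to the external code.
    return jnp.asarray(external_power(np.asarray(x), n))


ext_power_p.def_impl(_ext_power_impl)
# Abstract evaluation: the output has the same shape and dtype as x.
ext_power_p.def_abstract_eval(lambda x, *, n: x)


def _ext_power_jvp(primals, tangents, *, n):
    (x,), (t,) = primals, tangents
    # d/dx x**n = n * x**(n-1); multiplying by t is an ordinary JAX op,
    # which JAX already knows how to transpose.
    return ext_power(x, n=n), n * ext_power(x, n=n - 1) * t


ad.primitive_jvps[ext_power_p] = _ext_power_jvp

print(jax.make_jaxpr(lambda x: ext_power(x, n=3))(2.0))
print(jax.grad(lambda x: ext_power(x, n=3))(2.0))
```

jax.make_jaxpr shows a single ext_power equation rather than NumPy internals. Running this under jax.jit would additionally need a lowering rule, which is not shown here.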

(I haven't dug into those links yet.)

cc @jakevdp since this is related to sparsity

yuanqing-wang commented 3 years ago

Hi @mattjj, if I use a custom JVP rule for a core.Primitive, is there a way to define nondiff_argnums then?

yuanqing-wang commented 3 years ago

Hi @mattjj, sorry, I still haven't figured out how to do this: is it possible to specify nondiff_argnums for the JVP rule of a core.Primitive?