jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jacrev and jacfwd usage example #47

Closed absudabsu closed 5 years ago

absudabsu commented 5 years ago

Is there a simple example of how to use jacrev and jacfwd? There's currently no useful docstring. Some usage details would be helpful.

For example: (1) When calling grad(fun), does it differentiate w.r.t. the first input argument? (2) How do I use jacrev(func) when func takes multiple inputs (e.g., to differentiate w.r.t. the 3rd input variable)?

Thanks.

mattjj commented 5 years ago

You're right, we need to add docstrings. Things work roughly as in Autograd; in particular, both jacrev and jacfwd work roughly like Autograd's jacobian (which itself uses reverse mode), though a bit less flexibly.

As a stop-gap, here are some quick answers to the questions you raised:

grad(fun, argnums) differentiates fun with respect to the positional argument(s) given by argnums. The argnums argument defaults to 0 and can be an integer or a tuple of integers indicating which positional argument(s) to differentiate with respect to. For example:

```python
from jax import grad

def f(x, y):
  return 2 * x * y

grad(f)(3., 4.)  # 8.  (df/dx = 2y)
grad(f, 0)(3., 4.)  # 8.
grad(f, 1)(3., 4.)  # 6.  (df/dy = 2x)
grad(f, (0, 1))(3., 4.)  # (8., 6.)
```

grad can only be applied to a scalar-output function.
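The same argnums mechanism covers the "3rd input variable" case from the question. Below is a minimal sketch; the three-argument function g is hypothetical, and the last lines show grad rejecting a vector-valued function because of the scalar-output restriction.

```python
import jax.numpy as jnp
from jax import grad

# Hypothetical three-argument function.
def g(x, y, z):
    return x * y * z ** 2

# argnums=2 selects the third positional argument z: dg/dz = 2 * x * y * z
dg_dz = grad(g, argnums=2)(1., 2., 3.)
print(dg_dz)  # 12.0

# A vector-valued function is rejected: grad raises a TypeError for non-scalar outputs.
try:
    grad(lambda v: 2. * v)(jnp.ones(3))
except TypeError as err:
    print("TypeError:", err)
```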

jacfwd and jacrev currently only accept single-argument functions, so if you have a multiple-argument function you want to differentiate, you'll need to wrap it in another function by hand. For example:

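```python
from jax import jacrev

# Fix one argument of f (defined above) and differentiate w.r.t. the other.
jacrev(lambda x: f(x, 4.))(3.)  # 8.
jacrev(lambda y: f(3., y))(4.)  # 6.
```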
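The same hand-wrapping works for array arguments, and either a lambda or functools.partial can fix the other arguments. This is a minimal sketch; the function h and its inputs are hypothetical. It differentiates w.r.t. the third argument and checks that jacfwd and jacrev agree.

```python
from functools import partial

import jax.numpy as jnp
from jax import jacfwd, jacrev

# Hypothetical function of three array arguments with a vector-valued output.
def h(a, b, c):
    return a * jnp.sin(b) + c ** 2

a = jnp.array([1., 2., 3.])
b = jnp.array([0.1, 0.2, 0.3])
c = jnp.array([4., 5., 6.])

# Differentiate w.r.t. the third argument by closing over the first two.
J_rev = jacrev(lambda c_: h(a, b, c_))(c)
# functools.partial binds a and b positionally and gives the same function of c.
J_fwd = jacfwd(partial(h, a, b))(c)

print(J_rev.shape)                 # (3, 3)
print(jnp.allclose(J_rev, J_fwd))  # True
```

Here the Jacobian is diagonal with entries 2 * c, since each output depends only on the matching entry of c.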

jacfwd and jacrev can only be applied to functions with array (or scalar) inputs and array (or scalar) outputs. They should always give you the same answer (up to numerical precision) and differ only in speed. jacfwd uses forward-mode autodiff and so is better for "tall" Jacobian matrices (more outputs than inputs), while jacrev uses reverse-mode autodiff and so is better for "wide" Jacobian matrices (more inputs than outputs). A grad corresponds to the special case of a Jacobian matrix consisting of a single row.
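To make the tall/wide distinction concrete, here is a minimal sketch with two hypothetical functions, one mapping 2 inputs to 100 outputs and one mapping 100 inputs to 2 outputs. Forward mode computes one Jacobian-vector product per input (column) and reverse mode one vector-Jacobian product per output (row), so each function pairs naturally with one of the two transforms. The sketch also checks that both modes agree and that grad matches a single row of jacrev.

```python
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev

# Hypothetical "tall" function: 2 inputs -> 100 outputs (Jacobian is 100 x 2).
def tall(x):
    t = jnp.linspace(0., 1., 100)
    return jnp.sin(x[0] * t) + x[1]

# Hypothetical "wide" function: 100 inputs -> 2 outputs (Jacobian is 2 x 100).
def wide(x):
    return jnp.stack([jnp.sum(x ** 2), jnp.sum(jnp.cos(x))])

x_tall = jnp.array([1.5, -0.5])
x_wide = jnp.linspace(0., 1., 100)

# Forward mode: one JVP per input, i.e. 2 here instead of 100 VJPs.
J_tall = jacfwd(tall)(x_tall)
# Reverse mode: one VJP per output, i.e. 2 here instead of 100 JVPs.
J_wide = jacrev(wide)(x_wide)
print(J_tall.shape, J_wide.shape)  # (100, 2) (2, 100)

# Both modes give the same matrix up to floating-point error.
print(jnp.allclose(J_tall, jacrev(tall)(x_tall)))  # True
# grad of a scalar output equals the corresponding single row of the Jacobian.
print(jnp.allclose(grad(lambda x: wide(x)[0])(x_wide), J_wide[0]))  # True
```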

absudabsu commented 5 years ago

Thanks, this is helpful! Hopefully this will help some new folks.

As a side note: is there a technical reason why jacrev can't support selection via argnums?

mattjj commented 5 years ago

Not that I know of. We just haven't written that wrapper.
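That wrapper was added later: in current JAX releases, both jacfwd and jacrev take an argnums parameter that behaves like grad's, so the hand-written lambda wrappers are no longer required. A minimal sketch, assuming a recent JAX version; the function f_vec and its inputs are hypothetical.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

def f(x, y):
    return 2 * x * y

# argnums picks the positional argument, as with grad; no lambda wrapper needed.
print(jacrev(f, argnums=1)(3., 4.))  # 6.0
print(jacfwd(f, argnums=1)(3., 4.))  # 6.0

# Hypothetical vector-valued example: a tuple of argnums returns a tuple of Jacobians.
def f_vec(W, x):
    return jnp.tanh(W @ x)

W = jnp.ones((2, 3))
x = jnp.arange(3.)
dW, dx = jacrev(f_vec, argnums=(0, 1))(W, x)
print(dW.shape, dx.shape)  # (2, 2, 3) (2, 3)
```

Each Jacobian's shape is the output shape followed by the shape of the corresponding input.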

mattjj commented 5 years ago

Since #51 is about docstrings for all api.py functions, let's close this issue and use #51 to track our progress on documenting things. (Please reopen if that's not a good plan.)