Closed absudabsu closed 5 years ago
You're right, we need to add docstrings. Things work roughly like in Autograd, and in particular both `jacrev` and `jacfwd` work roughly like Autograd's `jacobian` (which itself uses reverse-mode), though a bit less flexibly.
As a stop-gap, here are some quick answers to the questions you raised:
`grad(fun, argnums)` differentiates `fun` with respect to positional argument(s) `argnums`. The `argnums` argument defaults to 0 and can be an integer or a tuple of integers indicating which positional argument(s) to differentiate with respect to. For example:
```python
from jax import grad

def f(x, y):
    return 2 * x * y

grad(f)(3., 4.)          # 8.
grad(f, 0)(3., 4.)       # 8.
grad(f, 1)(3., 4.)       # 6.
grad(f, (0, 1))(3., 4.)  # (8., 6.)
```
`grad` can only be applied to a scalar-output function.
`jacfwd` and `jacrev` currently only accept single-argument functions, so if you have a multiple-argument function you want to differentiate, you'll need to wrap it in another function by hand. For example:
```python
jacrev(lambda x: f(x, 4.))(3.)  # 8.
jacrev(lambda y: f(3., y))(4.)  # 6.
```
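The same wrapping trick covers the "differentiate w.r.t. the 3rd input" case from the question. Here is a minimal sketch; the function `g` and the values `a0`, `b0`, `c0` are made up for illustration and are not part of JAX:

```python
import jax.numpy as jnp
from jax import jacrev

# Hypothetical three-argument function with a vector output.
def g(a, b, c):
    return jnp.stack([a * c, b * c ** 2])

a0, b0, c0 = 1.0, 2.0, 3.0

# Close over the first two arguments so jacrev only sees the third.
dg_dc = jacrev(lambda c: g(a0, b0, c))(c0)
```

Since `c` is a scalar and the output has two entries, `dg_dc` should hold the two partial derivatives with respect to `c`, namely `a0` and `2 * b0 * c0`.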
`jacfwd` and `jacrev` can only be applied to functions with array (or scalar) inputs and array (or scalar) outputs. They should always give you the same answer (up to numerical precision), and only differ in speed. `jacfwd` uses forward-mode autodiff and so is better for "tall" Jacobian matrices (more outputs than inputs), while `jacrev` uses reverse-mode and so is better for "wide" Jacobian matrices (more inputs than outputs), with `grad` corresponding to the special case of a Jacobian matrix consisting of a single row.
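To make the comparison concrete, here is a small sketch that checks the two modes against each other. The function `h` and the input `x` are invented for the example; only `grad`, `jacfwd`, and `jacrev` come from JAX:

```python
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev

# Hypothetical function from R^3 to R^2, so its Jacobian is 2 x 3.
def h(x):
    return jnp.stack([jnp.sum(x ** 2), x[0] * x[1] * x[2]])

x = jnp.array([1.0, 2.0, 3.0])

J_fwd = jacfwd(h)(x)  # forward mode: one pass per input (column)
J_rev = jacrev(h)(x)  # reverse mode: one pass per output (row)

# grad of a single scalar output corresponds to one row of the Jacobian.
row0 = grad(lambda x: h(x)[0])(x)

same = bool(jnp.allclose(J_fwd, J_rev))
```

With three inputs and two outputs this Jacobian is slightly "wide", so reverse mode needs fewer passes here, though at this size the difference is negligible.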
Thanks, this is helpful! Hopefully this will help some new folks.
As a sidenote: is there a technical reason why `jacrev` can't support selection via `argnums`?
Not that I know of. We just haven't written that wrapper.
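For anyone who wants this in the meantime, one possible shape for such a wrapper is sketched below. The name `jacrev_wrt` and its `argnum` parameter are hypothetical, not part of the JAX API, and this version handles only a single integer index:

```python
import jax.numpy as jnp
from jax import jacrev

def jacrev_wrt(fun, argnum=0):
    """Hypothetical helper: Jacobian of fun w.r.t. one positional argument."""
    def jac_fun(*args):
        # Rebuild the argument list with x in position argnum,
        # holding every other argument fixed.
        def single_arg(x):
            new_args = args[:argnum] + (x,) + args[argnum + 1:]
            return fun(*new_args)
        return jacrev(single_arg)(args[argnum])
    return jac_fun

def f(x, y):
    return 2 * x * y

df_dx = jacrev_wrt(f, 0)(3., 4.)
df_dy = jacrev_wrt(f, 1)(3., 4.)
```

Under these assumptions the results should agree with the `grad(f, 0)` and `grad(f, 1)` values given earlier in the thread. Swapping `jacrev` for `jacfwd` inside the helper would give the forward-mode equivalent.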
Since #51 is about docstrings for all api.py functions, let's close this issue and use #51 to track our progress on documenting things. (Please reopen if that's not a good plan.)
Is there a simple example of how to use jacrev and jacfwd? There's currently no useful docstring. Some usage details would be helpful.
For example: (1) when calling grad(fun), does it differentiate w.r.t. the first input argument? (2) how to use jacrev(func), when func takes multiple inputs? (e.g. differentiate w.r.t. 3rd input variable)
Thanks.