google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Advanced Autodiff Cookbook? #12076

Open OhadRubin opened 1 year ago

OhadRubin commented 1 year ago

In the Autodiff Cookbook notebook in the documentation (https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), the author lists topics that he would like to showcase in a future Autodiff Cookbook:


1. Gauss-Newton Vector Products, linearizing once

2. Custom VJPs and JVPs

3. Efficient derivatives at fixed-points

4. Estimating the trace of a Hessian using random Hessian-vector products.

5. Forward-mode autodiff using only reverse-mode autodiff.

6. Taking derivatives with respect to custom data types.

7. Checkpointing (binomial checkpointing for efficient reverse-mode, not model snapshotting).

8. Optimizing VJPs with Jacobian pre-accumulation.
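While that tutorial does not exist yet, three of the listed items (1, 4 and 5) can already be sketched with public JAX APIs (`jax.linearize`, `jax.linear_transpose`, `jax.jvp`, `jax.vjp`, `jax.grad`, `jax.vmap`). This is a minimal sketch, not the planned cookbook content: the toy residual model and the helper names `make_gnvp`, `hutchinson_trace` and `jvp_via_vjp` are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# --- Item 1: Gauss-Newton vector products, linearizing once ---
# Hypothetical toy model: residuals r(w) = tanh(x @ w) - y.
def residuals(w, x, y):
    return jnp.tanh(x @ w) - y

def make_gnvp(w, x, y):
    # Linearize once: f_jvp computes v -> J v from saved intermediates.
    _, f_jvp = jax.linearize(lambda w_: residuals(w_, x, y), w)
    # Transpose that linear map to get u -> J^T u without re-linearizing.
    f_vjp = jax.linear_transpose(f_jvp, w)

    def gnvp(v):
        (out,) = f_vjp(f_jvp(v))  # J^T (J v), the Gauss-Newton product
        return out

    return gnvp

# --- Item 4: Hutchinson estimate of tr(H) from random HVPs ---
def hvp(f, x, v):
    # Forward-over-reverse Hessian-vector product.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hutchinson_trace(f, x, key, num_samples=100):
    # Rademacher probes satisfy E[v v^T] = I, so E[v^T H v] = tr(H).
    vs = jax.random.rademacher(key, (num_samples,) + x.shape).astype(x.dtype)
    quads = jax.vmap(lambda v: jnp.vdot(v, hvp(f, x, v)))(vs)
    return jnp.mean(quads)

# --- Item 5: forward mode using only reverse mode ---
def jvp_via_vjp(f, x, v):
    y, f_vjp = jax.vjp(f, x)
    # f_vjp is linear in its cotangent u (u -> J^T u), so the VJP of
    # f_vjp maps v -> J v, which is exactly the JVP of f at x.
    _, f_vjp_vjp = jax.vjp(f_vjp, jnp.zeros_like(y))
    (jv,) = f_vjp_vjp((v,))
    return y, jv
```

The Gauss-Newton sketch pays for linearization once per parameter point and reuses it for every product, which is the point of "linearizing once". The double-VJP trick in `jvp_via_vjp` is mainly of conceptual interest, since JAX already provides `jax.jvp` directly.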

What is the status of this possible future tutorial?

hawkinsp commented 1 year ago

We haven't written it yet :-)

If you're interested in any of these topics, would you like to help extend the current autodiff cookbook?

Some of these topics do have their own documentation already, e.g., Custom JVPs/VJPs.
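For reference, the two rules below follow examples from JAX's existing "Custom derivative rules" tutorial: a `jax.custom_jvp` rule that keeps the derivative of `log1pexp` finite for large inputs, and a `jax.custom_vjp` rule that clips gradients on the backward pass. The function names are illustrative; the decorators and `defjvp`/`defvjp` methods are the public JAX API.

```python
import jax
import jax.numpy as jnp

# Custom JVP: log(1 + exp(x)) whose naive derivative is nan for large x.
@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # Stable form of the derivative exp(x) / (1 + exp(x)).
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

# Custom VJP: identity on the forward pass, clipped cotangent on the backward pass.
@jax.custom_vjp
def clip_gradient(lo, hi, x):
    return x

def clip_gradient_fwd(lo, hi, x):
    return x, (lo, hi)  # save the bounds as residuals

def clip_gradient_bwd(res, g):
    lo, hi = res
    # None means zero cotangents for lo and hi.
    return (None, None, jnp.clip(g, lo, hi))

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
```

With these rules, `jax.grad(log1pexp)(100.0)` evaluates to 1.0 instead of nan, and the gradient of `10.0 * clip_gradient(-1.0, 1.0, x)` with respect to `x` is clipped to 1.0.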

mattjj commented 1 year ago

C'mon, it's only been like 4 years, don't rush us! 😛

We even have an expanded list of stuff to talk about, like jet.
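As a taste of `jet` (Taylor-mode, higher-order forward differentiation), the sketch below pushes the input path t -> x + t through `jnp.sin`. It assumes the `jet(fun, primals, series)` signature from `jax.experimental.jet`, which is experimental and may change; as in that function's docstring example, the series entries are derivative values rather than factorial-scaled Taylor coefficients.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

x = 0.5
# Input path t -> x + t: first derivative 1, higher derivatives 0.
y0, (y1, y2, y3) = jet(jnp.sin, (x,), ((1.0, 0.0, 0.0),))
# y0, y1, y2, y3 approximate sin(x), cos(x), -sin(x), -cos(x).
```

One `jet` call returns all derivatives up to the requested order, instead of nesting `jax.jvp` three times.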

As another example of "things that are documented but not collected into a Cookbook Part 2", we wrote a bit about fixed-point (and ODE) differentiation in this NeurIPS 2020 tutorial.
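The fixed-point case can be sketched with `jax.custom_vjp`: the forward pass solves x* = f(a, x*) by plain iteration, and the backward pass applies the implicit function theorem, where the cotangent u itself solves the fixed-point problem u = x_bar + (df/dx)^T u. This sketch is adapted from the fixed-point example in JAX's "Custom derivative rules" documentation rather than from the NeurIPS tutorial; the 1e-6 tolerance and the scalar-only convergence test are simplifying assumptions.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
    # Iterate x <- f(a, x) until successive iterates stop changing.
    def cond_fun(carry):
        x_prev, x = carry
        return jnp.abs(x_prev - x) > 1e-6

    def body_fun(carry):
        _, x = carry
        return x, f(a, x)

    _, x_star = jax.lax.while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
    return x_star

def fixed_point_fwd(f, a, x_init):
    x_star = fixed_point(f, a, x_init)
    return x_star, (a, x_star)

def rev_iter(f, packed, u):
    # One step of u <- x_bar + (df/dx)^T u at the fixed point.
    a, x_star, x_star_bar = packed
    _, vjp_x = jax.vjp(lambda x: f(a, x), x_star)
    return x_star_bar + vjp_x(u)[0]

def fixed_point_rev(f, res, x_star_bar):
    a, x_star = res
    _, vjp_a = jax.vjp(lambda a: f(a, x_star), a)
    u = fixed_point(partial(rev_iter, f), (a, x_star, x_star_bar), x_star_bar)
    (a_bar,) = vjp_a(u)
    # No gradient flows to the initial guess.
    return a_bar, jnp.zeros_like(x_star)

fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)

# Usage: Newton's iteration for sqrt(a) as a fixed point.
def newton_sqrt(a):
    update = lambda a, x: 0.5 * (x + a / x)
    return fixed_point(update, a, a)
```

Because the backward pass only needs x*, it does not differentiate through the forward iterations, so memory use does not grow with the number of iterations.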

moskomule commented 1 month ago

I'm waiting for the updates for another two years 👀