jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Advanced Autodiff Cookbook? #12076

Open OhadRubin opened 2 years ago

OhadRubin commented 2 years ago

In the Autodiff Cookbook notebook in the documentation, here: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

The authors describe topics they would like to showcase in a future Advanced Autodiff Cookbook:


1. Gauss-Newton Vector Products, linearizing once

2. Custom VJPs and JVPs

3. Efficient derivatives at fixed-points

4. Estimating the trace of a Hessian using random Hessian-vector products (see the sketch after this list).

5. Forward-mode autodiff using only reverse-mode autodiff.

6. Taking derivatives with respect to custom data types.

7. Checkpointing (binomial checkpointing for efficient reverse-mode, not model snapshotting).

8. Optimizing VJPs with Jacobian pre-accumulation.
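
For item 4 above, here is a minimal sketch of what such a recipe could look like: a Hutchinson trace estimator built on forward-over-reverse Hessian-vector products. The helpers `hvp` and `hutchinson_trace` are illustrative names, not existing JAX APIs.

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Hessian-vector product via forward-over-reverse:
    # differentiate grad(f) along direction v without materializing the Hessian.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hutchinson_trace(f, x, key, num_samples=64):
    # Hutchinson estimator: trace(H) ~ mean_i v_i^T H v_i with Rademacher probes v_i.
    def one_sample(k):
        v = jax.random.rademacher(k, x.shape).astype(x.dtype)
        return jnp.vdot(v, hvp(f, x, v))
    keys = jax.random.split(key, num_samples)
    return jnp.mean(jax.vmap(one_sample)(keys))

# Quick check against the exact trace on a small problem.
f = lambda x: jnp.sum(jnp.sin(x) ** 2)
x = jnp.linspace(0.0, 1.0, 5)
print(hutchinson_trace(f, x, jax.random.PRNGKey(0), num_samples=1000))
print(jnp.trace(jax.hessian(f)(x)))
```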

What is the status of this possible future tutorial?

hawkinsp commented 2 years ago

We haven't written it yet :-)

If any of these topics interest you, would you be interested in helping to extend the current autodiff cookbook?

Some of these topics do have their own documentation already, e.g., Custom JVPs/VJPs.
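
For reference, that doc centers on the `jax.custom_jvp` / `jax.custom_vjp` decorators. A minimal sketch, using a `log1pexp` example along the lines of the one in that doc:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    # Naive definition; autodiff of this is numerically unstable for large x.
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    y = log1pexp(x)
    # Hand-written, numerically stable derivative: sigmoid(x).
    y_dot = x_dot / (1.0 + jnp.exp(-x))
    return y, y_dot

print(jax.grad(log1pexp)(100.0))  # 1.0; naive autodiff would give nan here
```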

mattjj commented 2 years ago

C'mon, it's only been like 4 years, don't rush us! 😛

We even have an expanded list of stuff to talk about, like jet.
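
For the curious, jet lives in `jax.experimental.jet` and does Taylor-mode (higher-order) differentiation in a single pass. A tiny sketch, with arbitrary input coefficients:

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

# Taylor-mode AD: push a truncated Taylor series through a function in one pass,
# instead of nesting jvp calls once per order.
x0 = 0.5
# Input series coefficients for the expansion of x(t) around t = 0 (illustrative values).
primal_out, series_out = jet(jnp.sin, (x0,), ((1.0, 0.0),))
print(primal_out)   # sin(x0)
print(series_out)   # higher-order Taylor coefficients of sin(x(t)) at t = 0
```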

As another example of "things that are documented but not collected into a Cookbook Part 2", we wrote a bit about fixed-point (and ODE) differentiation in this NeurIPS 2020 tutorial.
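
The core trick there, if you want to play with it before a Cookbook Part 2 exists, is an implicit-function VJP at the fixed point rather than differentiating through the solver's iterations. A rough sketch along those lines (`fwd_solver`, `fixed_point`, and the Newton-for-sqrt example are illustrative choices, not an official API):

```python
import jax
import jax.numpy as jnp
from functools import partial

def fwd_solver(g, x0, tol=1e-6):
    # Naive fixed-point iteration x <- g(x) until successive iterates agree.
    def cond(carry):
        x_prev, x = carry
        return jnp.max(jnp.abs(x - x_prev)) > tol
    def body(carry):
        _, x = carry
        return x, g(x)
    _, x_star = jax.lax.while_loop(cond, body, (x0, g(x0)))
    return x_star

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x0):
    # Solve x* = f(a, x*) by plain iteration.
    return fwd_solver(lambda x: f(a, x), x0)

def fixed_point_fwd(f, a, x0):
    x_star = fixed_point(f, a, x0)
    return x_star, (a, x_star)

def fixed_point_rev(f, res, x_star_bar):
    a, x_star = res
    # Implicit-function VJP: solve w = x_star_bar + (df/dx)^T w at the fixed point,
    # then pull w back through df/da. Memory cost does not grow with the number of
    # forward iterations.
    _, vjp_x = jax.vjp(lambda x: f(a, x), x_star)
    w = fwd_solver(lambda w: x_star_bar + vjp_x(w)[0], x_star_bar)
    _, vjp_a = jax.vjp(lambda a: f(a, x_star), a)
    a_bar, = vjp_a(w)
    return a_bar, jnp.zeros_like(x_star)  # no gradient through the initial guess

fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)

# Example: sqrt(a) as the fixed point of Newton's update x <- (x + a/x) / 2.
f = lambda a, x: 0.5 * (x + a / x)
print(jax.grad(lambda a: fixed_point(f, a, jnp.array(1.0)))(2.0))  # ~ 1 / (2 * sqrt(2))
```

The adjoint iteration converges when the spectral radius of df/dx at the fixed point is below 1, which is the same condition under which the forward iteration converges locally.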

moskomule commented 3 months ago

I've been waiting for these updates for another two years 👀