google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Add HOWTO: Custom gradients in Flax #897

Closed: marcvanzee closed this issue 2 years ago

marcvanzee commented 3 years ago

The PixelCNN example uses custom gradients; see pixelcnn.py.
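For readers new to the topic, here is a minimal sketch of a custom backward pass written with plain `jax.custom_vjp`. It is not taken from pixelcnn.py. `clip_gradient` is a hypothetical helper name, not a Flax API. The function is the identity in the forward pass and clips the incoming gradient in the backward pass.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper (not part of Flax): identity forward, clipped gradient backward.
# `lo` and `hi` are marked non-differentiable, so the backward rule receives them first.
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
    return x


def clip_gradient_fwd(lo, hi, x):
    # Forward rule: return the output plus residuals (none are needed here).
    return x, None


def clip_gradient_bwd(lo, hi, _residuals, g):
    # Backward rule: one cotangent per differentiable input (only `x`).
    return (jnp.clip(g, lo, hi),)


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)


def loss(x):
    # Without clipping, the gradient of sum(3 * x) would be 3 for every element.
    return jnp.sum(3.0 * clip_gradient(-1.0, 1.0, x))


grads = jax.grad(loss)(jnp.array([0.5, 2.0]))  # clipped to 1.0 per element
```

Because `clip_gradient` is an ordinary JAX function, it should be callable on arrays inside a Flax module's `__call__` like any other `jax.numpy` operation. Flax's own lifted transforms are a separate option and are not shown here.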

The Autodiff Cookbook by the JAX team is also useful background reading.
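Custom gradients are often used to fix numerical problems. Below is a sketch using `jax.custom_jvp` on a hypothetical `log1pexp` function, computing log(1 + e^x). The naive derivative, e^x / (1 + e^x), becomes inf / inf = NaN for large x in float32. The rewritten rule, 1 - 1 / (1 + e^x), stays finite. This is an illustrative example, not code from the Flax repository.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    # Hypothetical example function: log(1 + e^x).
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    ans = log1pexp(x)
    # Algebraically equal to e^x / (1 + e^x), but it avoids inf / inf for large x.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot


grad_at_100 = jax.grad(log1pexp)(100.0)  # finite instead of NaN
grad_at_0 = jax.grad(log1pexp)(0.0)
```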

marcvanzee commented 2 years ago

Closing this since there has been no interest in this HOWTO.