hzdr / autodiff101

An introduction to Automatic Differentiation with theory and code examples.
BSD 3-Clause "New" or "Revised" License

Add examples for custom derivative definitions in PyTorch #1

Open elcorto opened 3 years ago

elcorto commented 8 months ago

The example code should follow test_jax.py and, where possible, implement the same operations, so that the libraries can be compared easily.
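A minimal sketch of what such an example could look like, using the classic `torch.autograd.Function` style where `forward` receives `ctx`. The function `Sin` (sin(x) with a hand-written derivative) is a hypothetical illustration and is not taken from test_jax.py. The `forward`/`backward` split roughly corresponds to the `fwd`/`bwd` pair of `jax.custom_vjp`, with `ctx.save_for_backward` playing the role of the residuals.

```python
import torch


class Sin(torch.autograd.Function):
    """Hypothetical example: sin(x) with a hand-written derivative."""

    @staticmethod
    def forward(ctx, x):
        # Store what backward() needs (like residuals in jax.custom_vjp's fwd)
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_out):
        # Vector-Jacobian product: d sin(x)/dx = cos(x), applied elementwise
        (x,) = ctx.saved_tensors
        return grad_out * torch.cos(x)


# Reverse mode through the custom backward
x = torch.linspace(-3.0, 3.0, 7, dtype=torch.float64, requires_grad=True)
Sin.apply(x).sum().backward()
print(torch.allclose(x.grad, torch.cos(x.detach())))  # True

# Check the hand-written derivative against finite differences
x0 = torch.randn(5, dtype=torch.float64, requires_grad=True)
print(torch.autograd.gradcheck(Sin.apply, (x0,)))  # True
```

`torch.autograd.gradcheck` compares the custom `backward` with numerical derivatives, which is a convenient counterpart to checking a `jax.custom_vjp` rule against `jax.grad` of the plain function.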

elcorto commented 2 months ago

As of torch 2.0, there is torch.func (formerly functorch), which implements a subset of the jax API (e.g. torch.func.grad behaves like jax.grad). Custom derivatives defined via torch.autograd.Function can also be used with the torch.func API, although this seems more complex to set up (e.g. things like ctx.save_for_backward, plus a separate setup_context staticmethod instead of passing ctx to forward).
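A sketch of the torch.func variant, assuming torch >= 2.0. To be usable with torch.func transforms, `forward` must not take `ctx`; saving tensors moves into `setup_context`. `MySin` and `f` are hypothetical names chosen for illustration. Only `torch.func.grad` is exercised here; `vmap` or `jvp` would additionally need a `vmap`/`jvp` staticmethod (or `generate_vmap_rule = True` for `vmap`).

```python
import torch
from torch.func import grad


class MySin(torch.autograd.Function):
    """Hypothetical example: sin(x) with a custom derivative, torch.func-compatible."""

    @staticmethod
    def forward(x):
        # No ctx here: required for use with torch.func transforms
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Save what backward() needs, separately from forward()
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * torch.cos(x)


def f(x):
    # Scalar-valued function, as required by torch.func.grad (like jax.grad)
    return MySin.apply(x)


x = torch.tensor(0.5)
print(grad(f)(x))          # same value as torch.cos(x)
print(grad(torch.sin)(x))  # built-in derivative, for comparison
```

The call pattern `grad(f)(x)` mirrors `jax.grad(f)(x)`, which should make a side-by-side comparison with the jax examples straightforward.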