Open elcorto opened 3 years ago
The example code should follow `test_jax.py` and implement the same operations where possible, so that the libraries are easy to compare.
As of torch 2.0, there is `torch.func` (formerly `functorch`), which implements a subset of the jax API (e.g. `torch.func.grad` behaves like `jax.grad`). The `torch.func` API can also be used with custom derivatives via `torch.autograd.Function`, even though that seems more complex to set up (e.g. things like `ctx.save_for_backward`).
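To make the `torch.func.grad` / `jax.grad` correspondence concrete, here is a minimal sketch. The function `func` and the input value are made up for illustration and are not taken from `test_jax.py`.

```python
import torch
from torch.func import grad


def func(x):
    # Hypothetical scalar function, chosen only for illustration.
    return torch.sin(x) ** 2


x = torch.tensor(0.5, dtype=torch.float64)

# Functional-style derivative; the jax counterpart would be jax.grad(func)(x).
dfdx = grad(func)(x)

# Analytic derivative for comparison: d/dx sin(x)^2 = 2 sin(x) cos(x)
dfdx_ref = 2 * torch.sin(x) * torch.cos(x)
```

As with `jax.grad`, `grad(func)` returns a new function, and no `requires_grad=True` or `.backward()` call is needed.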
- `torch.autograd.Function` derived class -> define `forward` and `backward` methods
- `torch.nn.Module.register_full_backward_hook`
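A rough sketch of the `torch.autograd.Function` route, assuming torch >= 2.0: the class `MySin` is a hypothetical example, not code from this repository. It uses the `setup_context` style (a `forward` without `ctx`), which to my understanding is what allows the same class to be used with `torch.func` transforms.

```python
import torch
from torch.func import grad


class MySin(torch.autograd.Function):
    """Hypothetical example: sin(x) with a hand-written derivative."""

    @staticmethod
    def forward(x):
        # torch >= 2.0 style: forward takes no ctx argument.
        return torch.sin(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Store what backward needs.
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Custom derivative: d/dx sin(x) = cos(x), times the incoming gradient.
        (x,) = ctx.saved_tensors
        return grad_output * torch.cos(x)


# Classic autograd usage
x = torch.tensor(0.5, dtype=torch.float64, requires_grad=True)
MySin.apply(x).backward()
grad_autograd = x.grad

# Same custom derivative through the torch.func API
grad_func = grad(MySin.apply)(x.detach())
```

Both paths should return `cos(0.5)`. The extra `setup_context` step is the kind of boilerplate that makes this more involved than defining a custom derivative rule in jax.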