pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Support for custom gradient functions #1031

Closed esteveste closed 1 year ago

esteveste commented 1 year ago

Does functorch have support for, or plans to add, custom differentiation rules for Python functions, like jax.custom_jvp and jax.custom_vjp?

zou3519 commented 1 year ago

Not yet, but we have plans to support it sometime by the end of the year.

zou3519 commented 1 year ago

@esteveste do you have a specific use case in mind?

esteveste commented 1 year ago

My use case is rather specific: I was approximating a function's forward and backward passes with a lookup table. In my experiments I've been using autograd.Function to implement that, which is stated to be incompatible with functorch. In my case I would not need double gradients, so being able to use autograd.Function directly with functorch might actually be more useful here.

To be honest, I mentioned the JAX examples because they were described as a way to implement custom gradient functions, although I haven't experimented with them myself.

Another thing I was experimenting with was using an optimized Triton kernel for training (both forward and backward), where I've also been using autograd.Function. It would be cool to be able to use it with vmap.
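As an illustration of the lookup-table use case described above, here is a minimal sketch of a classic torch.autograd.Function whose forward and backward passes both read from precomputed tables. The choice of tanh, the grid bounds, and the names LookupTanh, FWD_TABLE, BWD_TABLE, and _nearest_index are all assumptions made for this example, not details of the original experiments.

```python
import torch

# Hypothetical tables: tanh and its derivative sampled on a uniform grid.
GRID_MIN, GRID_MAX, GRID_SIZE = -4.0, 4.0, 1025
_grid = torch.linspace(GRID_MIN, GRID_MAX, GRID_SIZE)
FWD_TABLE = torch.tanh(_grid)
BWD_TABLE = 1.0 - torch.tanh(_grid) ** 2


def _nearest_index(x):
    # Map each input to the index of the nearest grid point, clamped to the table.
    step = (GRID_MAX - GRID_MIN) / (GRID_SIZE - 1)
    idx = torch.round((x - GRID_MIN) / step).long()
    return idx.clamp(0, GRID_SIZE - 1)


class LookupTanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Forward pass: read the approximate function value from the table.
        idx = _nearest_index(x)
        ctx.save_for_backward(idx)
        return FWD_TABLE[idx]

    @staticmethod
    def backward(ctx, grad_output):
        # Backward pass: read the approximate derivative from a second table.
        (idx,) = ctx.saved_tensors
        return grad_output * BWD_TABLE[idx]


x = torch.tensor([-1.0, 0.0, 0.5], requires_grad=True)
y = LookupTanh.apply(x)
y.sum().backward()  # fills x.grad using BWD_TABLE
```

Because the backward pass is itself just a table read, this sketch supports first-order gradients only, which matches the "no double gradients" requirement mentioned above.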

zou3519 commented 1 year ago

To check my understanding, it sounds like you want "an autograd.Function that works with functorch" rather than a "jax.custom_vjp in PyTorch", right? (We are planning on figuring out both, but I'm currently trying to prioritize).

The functions passed to jax.custom_vjp need to only call jax operations (and cannot call out to e.g. a custom triton kernel or a lookup table). This limitation makes it so that it's possible for the custom_vjp to automatically work with jax.vmap (a vmap rule gets generated from the forward/backward functions passed to jax.custom_vjp).

esteveste commented 1 year ago

Oops, you're right! For these cases an autograd.Function seems more appropriate.

zou3519 commented 1 year ago

> My use case is rather specific: I was approximating a function's forward and backward passes with a lookup table.

> Another thing I was experimenting with was using an optimized Triton kernel for training (both forward and backward), where I've also been using autograd.Function. It would be cool to be able to use it with vmap.

To use either of these use cases with vmap, assuming that functorch supported autograd.Function, the user would need to write an autograd.Function and define a vmap (batching) rule for it, much as one defines a custom forward and backward in an autograd.Function today. There isn't an easy way for the framework (functorch) to automatically generate a batching rule from just a forward pass and a backward pass. The batching rule could be naive (just run the operation in a for-loop), or, depending on the use case, it could be another lookup in the table, another custom Triton kernel, or, if we're lucky, the same Triton kernel reused.

If the batching rule needs to be differentiated... then it may need to be another user-defined autograd.Function that has a forward and backward.
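A rough sketch of the naive for-loop batching rule described above, written against the vmap staticmethod that torch.autograd.Function provides in PyTorch 2.0 and later (an assumption about the reader's version; that API did not exist when this comment was written). ScalarKernel and its x ** 3 body are hypothetical stand-ins for an opaque kernel or table lookup.

```python
import torch
from torch.func import grad, vmap


class ScalarKernel(torch.autograd.Function):
    """Hypothetical stand-in for an opaque kernel; computes x ** 3."""

    @staticmethod
    def forward(x):
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 3 * x ** 2

    @staticmethod
    def vmap(info, in_dims, x):
        # Naive batching rule: move the batch dimension to the front,
        # run the unbatched function once per sample, and stack the results.
        (x_bdim,) = in_dims
        x = x.movedim(x_bdim, 0)
        out = torch.stack(
            [ScalarKernel.apply(x[i]) for i in range(info.batch_size)]
        )
        # The second return value says which dimension of `out` is the batch.
        return out, 0


x = torch.randn(4, 3)
batched_out = vmap(ScalarKernel.apply)(x)
# Differentiating through the batching rule works here because the rule
# calls ScalarKernel.apply, which carries its own backward.
g = grad(lambda t: vmap(ScalarKernel.apply)(t).sum())(x)
```

The loop trades speed for simplicity. A dedicated batched kernel could replace the list comprehension without changing the rest of the class.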

@esteveste - I wanted to check here, if we provided these APIs (an autograd.Function that works with functorch and a way to define custom rules for vmap), would that be helpful for you? I imagine it would be work to e.g. write a custom triton kernel for the batching rule so I'm not sure if all users would use this feature.

esteveste commented 1 year ago

I think it would be pretty interesting for my experiments, but I must agree with you: this would probably be a very niche application.

I guess it's up to you to decide whether it's worth supporting custom vmap rules, or just keeping a simpler form of compatibility with autograd.Function.

kxhit commented 1 year ago

> Not yet, but we have plans to support it sometime by the end of the year.

Hi, is there an update on supporting torch.autograd.Function? Or is there currently any alternative way to do this with functorch.vmap?

zou3519 commented 1 year ago

> Hi, is there an update on supporting torch.autograd.Function? Or is there currently any alternative way to do this with functorch.vmap?

Hey! It works on the latest PyTorch nightly, but the docs haven't made it to the website yet. Please read through https://github.com/pytorch/pytorch/blob/master/docs/source/notes/extending.func.rst and let us know if you have questions.

The TL;DR is that to use autograd.Function with functorch, the user may need to refactor the autograd.Function a bit into a form that we are able to transform.
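For a rough idea of what that refactored form looks like, here is a minimal sketch assuming PyTorch 2.0 or later: forward no longer takes ctx, saving state moves into a separate setup_context staticmethod, and setting generate_vmap_rule = True asks PyTorch to derive the batching rule automatically (which only works when forward and backward consist of ordinary PyTorch operations). MySin is a hypothetical example name.

```python
import torch
from torch.func import grad, vmap


class MySin(torch.autograd.Function):
    # Ask PyTorch to generate the vmap rule from forward/backward.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        # No ctx argument here: forward only computes the output.
        return x.sin()

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Everything backward needs is saved here instead of in forward.
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * x.cos()


x = torch.randn(5)
# Per-sample gradients: grad differentiates the scalar function,
# vmap maps it over the leading dimension of x.
per_sample_grads = vmap(grad(MySin.apply))(x)
```

A function that calls into code PyTorch cannot see through (a NumPy routine, a custom kernel) would instead need a hand-written vmap staticmethod rather than generate_vmap_rule.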

zou3519 commented 1 year ago

The feature has been implemented and will be in the next PyTorch release. Docs are available at https://pytorch.org/docs/master/notes/extending.func.html. Closing this issue as completed, but please feel free to open new issues if you run into any problems.