pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

tree-based level system #1076

Open zou3519 opened 1 year ago

zou3519 commented 1 year ago

Motivation

We currently have a limitation: if someone has a custom operator, the kernels that implement that operator for functorch transforms are not allowed to call a functorch transform themselves. This is problematic for developers (e.g. the functionalize rule for cond should call functionalize on both branches) and for users (autograd.Function with functorch -- e.g. a user may not call a functorch transform inside backward).

Note that in general, users are able to call functorch transforms from within functorch transforms (e.g. vmap(vmap(vmap(f)))(x)); the distinction here is that one cannot call a functorch transform from inside an implementation of an operation.
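For reference, the nesting that already works looks roughly like the sketch below. It assumes PyTorch 2.x, where vmap is available as `torch.func.vmap` (older setups expose it as `functorch.vmap`); the function `f` and the input shape are made up for illustration.

```python
import torch
from torch.func import vmap  # assumption: PyTorch 2.x; older installs use functorch.vmap


def f(v):
    # f only ever sees a single 1-D vector of shape (5,) and reduces it to a scalar.
    return v.sum()


x = torch.randn(2, 3, 4, 5)

# Each vmap strips one leading batch dimension before calling the next layer,
# so three nested vmaps map f over the first three dimensions of x.
out = vmap(vmap(vmap(f)))(x)

print(out.shape)
```

Here all three transforms are entered from user code, one inside the other, and the result has shape (2, 3, 4). None of them is started from inside an operator's implementation, which is the case this issue is about.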

For further reading

More information in doc: https://docs.google.com/document/d/1iZ_ioIC8u6osb87Ue43yGrr4ACIQRx9dMlqg2KQx2co/edit#heading=h.49xwcme71bbe

cc @samdow @Chillee @ezyang -- the notion of "levels being a tree" sounds similar to some of the `__torch_dispatch__` + mode + functorch generalizations we thought through last year.

ezyang commented 1 year ago

Wait, so the inner usage of functorch has wrapper tensors that are escaping from their level?

zou3519 commented 1 year ago

Those wrapper tensors aren't escaping from their level. Wrapper tensors with the same level (but that come from different interpreters!) should never interact, assuming a correctly implemented functorch (and non-adversarial user code).

The immediate problem is that we have internal asserts somewhere that say there cannot be two different alive interpreters with the same level.
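
To make the distinction concrete, here is a toy model in plain Python. It is not functorch's implementation: `Interpreter`, `FlatRegistry`, and `TreeRegistry` are hypothetical names, and the rule that a level equals the depth below the parent is an assumption made for illustration. It only shows why an "at most one alive interpreter per level" assert rejects a situation that a tree of levels can represent.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass(eq=False)  # eq=False: interpreters are compared by identity, not by level
class Interpreter:
    """Hypothetical stand-in for one active transform (e.g. one vmap call)."""
    name: str
    level: int
    parent: Optional["Interpreter"] = None


class FlatRegistry:
    """Toy registry that asserts at most one alive interpreter per level."""

    def __init__(self) -> None:
        self.alive: Dict[int, Interpreter] = {}

    def push(self, name: str, level: int) -> Interpreter:
        if level in self.alive:
            raise AssertionError(f"level {level} already has an alive interpreter")
        interp = Interpreter(name, level)
        self.alive[level] = interp
        return interp


class TreeRegistry:
    """Toy registry where levels form a tree (assumption: level = depth)."""

    def __init__(self) -> None:
        self.alive: List[Interpreter] = []

    def push(self, name: str, parent: Optional[Interpreter] = None) -> Interpreter:
        level = 1 if parent is None else parent.level + 1
        interp = Interpreter(name, level, parent)
        self.alive.append(interp)
        return interp


# Flat model: a second alive interpreter with the same level trips the assert.
flat = FlatRegistry()
flat.push("outer transform", level=1)
try:
    flat.push("transform started inside a kernel", level=1)
    flat_rejected = False
except AssertionError:
    flat_rejected = True

# Tree model: two interpreters share a level but remain distinct objects.
tree = TreeRegistry()
root = tree.push("outer transform")
branch_a = tree.push("transform on branch A", parent=root)
branch_b = tree.push("transform on branch B", parent=root)
same_level = branch_a.level == branch_b.level
same_interpreter = branch_a is branch_b
```

In the tree model, a wrapper tensor would need to record which interpreter created it rather than only an integer level, so that two wrappers with equal levels from different interpreters can be told apart.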