pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Dedup error-checking code (from structured kernels) with batching rules #1068

Open zou3519 opened 1 year ago

zou3519 commented 1 year ago

Problem

Today, some of our batching rules need to duplicate error-checking code. For example, consider https://github.com/pytorch/pytorch/pull/89384, which adds a check to the structured kernel that the input to torch.triu has at least 2 dimensions.

If we did the following:

import torch
from functorch import vmap
x = torch.randn(32, 2)
vmap(torch.triu)(x)  # each per-example slice x[i] is 1D, shape (2,)

Today this would succeed, but we should expect it to fail, because inside vmap, x is logically a 1D tensor.

The batching rule today:

This does not trigger at::triu's error check, because the physical tensor passed to at::triu (batch dimension included) is indeed 2D. The fix is simple: add a check inside the batching rule that the input has at least 2 logical dimensions.
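A minimal, shape-only sketch of that fix, assuming a toy model in which a batched tensor is just its physical shape plus the index of its batch dimension. The names `logical_shape` and `triu_batch_rule` are hypothetical; functorch's real batching rules are written in C++ and operate on real tensors.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def logical_shape(physical: Shape, bdim: int) -> Shape:
    """Drop the vmapped (batch) dimension to get the per-example shape."""
    return physical[:bdim] + physical[bdim + 1:]


def triu_batch_rule(physical: Shape, bdim: int) -> Tuple[Shape, int]:
    """Hypothetical triu batching rule, operating on shapes only.

    The physical input always has one extra dimension, so the kernel's own
    ">= 2 dims" check cannot catch a 1D logical input. The rule therefore
    duplicates the check against the logical shape.
    """
    if len(logical_shape(physical, bdim)) < 2:
        raise RuntimeError("triu: input tensor must have at least 2 dimensions")
    # Move the batch dim to the front; triu acts on the last two dims,
    # which are then guaranteed to be logical dims.
    moved = (physical[bdim],) + logical_shape(physical, bdim)
    return moved, 0
```

With this, `triu_batch_rule((32, 2), 0)` raises, while `triu_batch_rule((32, 3, 3), 0)` returns `((32, 3, 3), 0)`. The duplication is the point: the same 2D test now lives in both the kernel and the rule.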

In the interest of not repeating ourselves, it would be cool if there were some way to reuse the error checks from the structured kernel.

Pitch

I'm not familiar with how the error checks get code-generated, but there's probably some way to add more codegen to vmap so that we can automagically add the error check and have it run before a batching rule.
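One way to picture the pitch, sketched in plain Python rather than codegen: each op registers a single error check written against logical shapes, and a wrapper runs it on the logical shape before the batching rule. `ERROR_CHECKS`, `register_error_check`, and `batch_rule` are hypothetical names under the same toy shape model; this is not how PyTorch's codegen is actually structured.

```python
import functools
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]
ErrorCheck = Callable[[Shape], None]

# Hypothetical registry: op name -> error check written against logical shapes.
ERROR_CHECKS: Dict[str, ErrorCheck] = {}


def register_error_check(op: str) -> Callable[[ErrorCheck], ErrorCheck]:
    """Record the single error check shared by the kernel and vmap."""
    def deco(fn: ErrorCheck) -> ErrorCheck:
        ERROR_CHECKS[op] = fn
        return fn
    return deco


def batch_rule(op: str):
    """Wrap a batching rule so the op's registered check runs first,
    on the logical shape (batch dim removed)."""
    def deco(rule):
        @functools.wraps(rule)
        def wrapped(physical: Shape, bdim: int):
            check = ERROR_CHECKS.get(op)
            if check is not None:
                check(physical[:bdim] + physical[bdim + 1:])
            return rule(physical, bdim)
        return wrapped
    return deco


@register_error_check("triu")
def triu_check(shape: Shape) -> None:
    if len(shape) < 2:
        raise RuntimeError("triu: input tensor must have at least 2 dimensions")


@batch_rule("triu")
def triu_batch_rule(physical: Shape, bdim: int) -> Tuple[Shape, int]:
    # No dimension check here: the wrapper already ran triu_check.
    # Move the batch dim to the front so triu's last two dims are logical dims.
    moved = (physical[bdim],) + physical[:bdim] + physical[bdim + 1:]
    return moved, 0
```

The batching rule body no longer repeats the check; calling `triu_batch_rule((32, 2), 0)` still raises because the registered check runs first.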

cc @bdhirsh @ezyang if you folks have thoughts

ezyang commented 1 year ago

So, an "easy" automatic way to do this is to run the meta implementation with the logical size (batched dimensions removed). It is not very efficient. How does that sound?
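A shape-only sketch of that suggestion, assuming a toy `MetaTensor` that stands in for a meta tensor (it carries a shape but no data). In real PyTorch the analogous step would be calling the operator on `torch.empty(logical_shape, device="meta")`. Here the hypothetical `triu_meta` both validates and "allocates" an output object, which is where the overhead discussed below comes from; `check_via_meta` is also a hypothetical name.

```python
from dataclasses import dataclass
from typing import Tuple

Shape = Tuple[int, ...]


@dataclass
class MetaTensor:
    """Stand-in for a meta tensor: shape only, no storage."""
    shape: Shape


def triu_meta(x: MetaTensor) -> MetaTensor:
    # Error check and output "allocation" happen in the same function.
    if len(x.shape) < 2:
        raise RuntimeError("triu: input tensor must have at least 2 dimensions")
    return MetaTensor(x.shape)


def check_via_meta(meta_fn, physical: Shape, bdim: int) -> None:
    """Run the op's meta function on the logical shape and discard the result."""
    logical = physical[:bdim] + physical[bdim + 1:]
    # Builds an input object and gets back an output object on every call.
    meta_fn(MetaTensor(logical))
```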

zou3519 commented 1 year ago

By "run the meta implementation", do you mean create meta tensors and run the operation on them?

Ideally we would turn this on for all batching rules by default (i.e., if an op has a structured kernel definition, we just run its error checking), so that we can eliminate this class of errors. The error checks would also run at each level of vmap.

In that model, running the meta implementation is likely too much overhead (creating a meta tensor involves a Tensor allocation).
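To see why per-level cost matters, here is a sketch of the checks running once per vmap level, in the same toy shape model. It assumes the batch dims of the nested levels are the leading physical dims (outermost first), so level k's logical shape is `physical[k:]`, and it visits levels innermost first. `run_checks_per_level` is a hypothetical helper, not a functorch API.

```python
from typing import Callable, List, Tuple

Shape = Tuple[int, ...]


def triu_check(shape: Shape) -> None:
    if len(shape) < 2:
        raise RuntimeError("triu: input tensor must have at least 2 dimensions")


def run_checks_per_level(check: Callable[[Shape], None],
                         physical: Shape, num_levels: int) -> List[Shape]:
    """Run `check` once per vmap level, innermost level first.

    Toy assumption: level k's batch dim is physical dim k - 1, so the
    logical shape seen at level k is physical[k:].
    Returns the logical shapes that were checked.
    """
    seen = []
    for level in range(num_levels, 0, -1):
        logical = physical[level:]
        check(logical)  # one check (and, with meta tensors, one allocation) per level
        seen.append(logical)
    return seen
```

For `vmap(vmap(torch.triu))` over a tensor of shape `(4, 32, 2)`, the innermost level sees logical shape `(2,)` and the check fails there; with shape `(4, 32, 3, 3)` both levels pass, seeing `(3, 3)` and then `(32, 3, 3)`.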

ezyang commented 1 year ago

Yes, create meta tensors. Bah, that is annoying.

The big problem is that structured kernels aren't really set up for the kind of error test you want. You want foo_error_check(blah, blah), which doesn't bother allocating the outputs and just does stuff like the 2D test. But to do this, we have to split up the meta func (because it both does error checks AND allocates the output tensors), and you also have to define what the arguments are (because the meta func is batching-oblivious). Actually, this is a situation where a vmap mode that magically transmutes tensors as-is would be helpful, since you would turn on the mode and then call the conventional function, and it would indeed see only a 1D tensor.
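A sketch of the split described here, in the same toy shape model: the checks move into a `triu_error_check` that creates nothing, the meta function calls it before producing its output, and the batching rule calls only the check, on the logical shape. These names follow the `foo_error_check` placeholder above and are hypothetical; as the comment notes, structured kernels are not set up to provide such a function today.

```python
from dataclasses import dataclass
from typing import Tuple

Shape = Tuple[int, ...]


@dataclass
class MetaTensor:
    """Stand-in for a meta tensor: shape only, no storage."""
    shape: Shape


def triu_error_check(shape: Shape) -> None:
    """Validation only: no output is created."""
    if len(shape) < 2:
        raise RuntimeError("triu: input tensor must have at least 2 dimensions")


def triu_meta(x: MetaTensor) -> MetaTensor:
    """Meta function = shared check + output "allocation"."""
    triu_error_check(x.shape)
    return MetaTensor(x.shape)


def triu_batch_rule(physical: Shape, bdim: int) -> Tuple[Shape, int]:
    """Batching rule reuses the check on the logical shape, allocation-free."""
    triu_error_check(physical[:bdim] + physical[bdim + 1:])
    # Move the batch dim to the front so triu's last two dims are logical dims.
    moved = (physical[bdim],) + physical[:bdim] + physical[bdim + 1:]
    return moved, 0
```

The design choice is that both callers share one source of truth for the check, while only the meta path pays for creating an output.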

bdhirsh commented 1 year ago

Is the tensor allocation the main source of unacceptable overhead?

We could probably figure out a way to avoid it, but another source of overhead is that a lot of our meta kernels are now in Python, and might involve decompositions (into more Python code).