pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Better "oops your tensor escaped from vmap" error message #1054

Closed: zou3519 closed this issue 1 year ago

zou3519 commented 1 year ago

Right now, if a Tensor escapes from vmap, any later use of it hits an internal assert. This is confusing for users and looks like a bug rather than expected behavior. We should raise a more descriptive error message.

zou3519 commented 1 year ago

The repro is:

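import torch
from functorch import vmap

lst = []

def f(x):
    # Stashing the batched per-example tensor in an outer list lets it
    # escape from vmap.
    lst.append(x)
    return x ** 2

x = torch.randn(3)
vmap(f)(x)

# Using the escaped tensor after vmap has returned fails with an internal assert:
lst[0].sin()
# RuntimeError: maybe_layer.has_value() INTERNAL ASSERT FAILED at "/raid/rzou/pt/ctrl/build/aten/src/ATen/VmapGeneratedPlumbing.h":10652, please report a bug to PyTorch.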

We should improve the error message so that, in addition to saying this may be a framework bug, it tells the user that they may be escaping a Tensor from vmap, which is not allowed (see https://pytorch.org/functorch/stable/ux_limitations.html).
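One way to avoid the escape in the repro is to return the intermediate from `f` as an extra output rather than appending it to an outer list. vmap then hands it back as an ordinary Tensor. This is a minimal sketch that assumes the same `functorch.vmap` API used in the repro, with vmap's default `out_dims=0` applied to each output in the returned tuple.

```python
import torch
from functorch import vmap

def f(x):
    # Return x as a second output instead of stashing it in an outer list,
    # so vmap unwraps it into a regular (non-batched) Tensor on the way out.
    return x ** 2, x

x = torch.randn(3)
squared, saved = vmap(f)(x)

# `saved` is an ordinary Tensor of shape (3,), so using it after vmap is fine.
print(saved.sin())
```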

srossross commented 1 year ago

This is fixed by https://github.com/pytorch/pytorch/pull/89077.