pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Printing unwrapped tensors doesn't work under grad-type transform #1026

Closed · samdow closed 1 year ago

samdow commented 1 year ago

Repro

```
import torch
from functorch import vjp  # grad and jvp hit the same error

x = torch.randn(3)  # plain tensor captured from the enclosing scope

def f(y):
    print(x)  # printing the unwrapped tensor under the transform fails
    return y

vjp(f, torch.randn(4))
```

This fails with: `RuntimeError: Cannot access data pointer of Tensor that doesn't have storage`

Full Stacktrace

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/samdow/Documents/jacfwd_fix/functorch/functorch/_src/eager_transforms.py", line 270, in vjp
    primals_out = func(*diff_primals)
  File "<stdin>", line 2, in f
  File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor.py", line 423, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/Users/samdow/Documents/jacfwd_fix/functorch/functorch/_src/monkey_patching.py", line 23, in _functorch_str
    return _old_str(tensor)
  File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor_str.py", line 594, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor_str.py", line 557, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor_str.py", line 320, in _tensor_str
    return _tensor_str_with_formatter(self, indent, summarize, formatter)
  File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor_str.py", line 249, in _tensor_str_with_formatter
    return _vector_str(self, indent, summarize, formatter1, formatter2)
  File "/Users/samdow/Documents/jacfwd_fix/torch/_tensor_str.py", line 230, in _vector_str
    data = [_val_formatter(val) for val in self.tolist()]
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
```

What we think is happening

There's some monkey patching that makes printing work under transforms. We think that because the transform is still active while `__repr__` runs, the intermediate tensors produced by the normal printing code path are getting wrapped in functorch tensor wrappers, which don't have storage, so the final `self.tolist()` call fails.
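If that theory is right, the same failure should be reproducible without `print` at all: any op on the captured unwrapped tensor should produce a storage-less wrapper under the transform. A minimal sketch of that idea (untested; `data_ptr()` stands in for the data-pointer access the print path performs):

```
import torch
from functorch import grad

x = torch.randn(3)  # plain tensor, owns real storage

def f(y):
    z = x.abs()   # dispatched through the active transform, so the result
                  # is (we believe) a tensor wrapper without storage
    z.data_ptr()  # expected to raise the same "doesn't have storage" error
    return y.sum()

grad(f)(torch.randn(3))
```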

Potential Solutions

One option would be to turn off all active transforms while printing (or even only when printing an unwrapped tensor), so the formatting ops dispatch to the ordinary kernels.
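Roughly, the existing monkey patch in `monkey_patching.py` could wrap the call to the original printer in a guard that temporarily suspends every functorch layer. A sketch of the idea only; `_disable_all_transforms()` is a hypothetical context manager standing in for whatever internal mechanism would pop and restore the interpreter stack:

```
import torch

_old_str = torch._tensor_str._str

def _functorch_str(tensor, *, tensor_contents=None):
    # Hypothetical: suspend every active functorch transform so the
    # formatting ops inside _old_str run as ordinary tensor ops with
    # real storage, then restore the transforms afterwards.
    with _disable_all_transforms():  # hypothetical helper, not a real API
        return _old_str(tensor, tensor_contents=tensor_contents)

torch._tensor_str._str = _functorch_str
```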

zou3519 commented 1 year ago

I'm surprised that no one has tried to do this until now

zou3519 commented 1 year ago

Fixed in https://github.com/pytorch/pytorch/pull/85556