pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

make_fx fails with `jacfwd` (when used with torch.add(Tensor, Scalar)) #1078

Status: Open. kshitij12345 opened this issue 1 year ago.

kshitij12345 commented 1 year ago
```python
import torch
import functorch

dtype = torch.float32
device = torch.device('cpu')

def foo(x):
    return x + 1.0

x = torch.tensor([[0.0]], dtype=dtype, device=device)

functorch.make_fx(functorch.vmap(foo))(x)  # Works
functorch.make_fx(functorch.jacrev(foo))(x)  # Works
functorch.make_fx(functorch.jacfwd(foo))(x)  # Fails
```

Error Message:

```
Traceback (most recent call last):
  File "/home/kshiteej/Pytorch/pytorch_functorch/test/test_scratch.py", line 31, in <module>
    functorch.make_fx(functorch.jacfwd(foo))(x)  # Fails
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 683, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 441, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py", line 739, in trace
    (self.create_arg(fn(*args)),),
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 457, in wrapped
    out = f(*tensors)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 996, in wrapper_fn
    results = vmap(push_jvp, randomness=randomness)(basis)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 362, in wrapped
    return _flat_vmap(
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 35, in fn
    return f(*args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 489, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 989, in push_jvp
    output = _jvp_with_argnums(func, args, basis, argnums=argnums, has_aux=has_aux)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 35, in fn
    return f(*args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 837, in _jvp_with_argnums
    result_duals = func(*duals)
  File "/home/kshiteej/Pytorch/pytorch_functorch/test/test_scratch.py", line 26, in foo
    return x + 1.0
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 483, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 508, in inner_torch_dispatch
    out = proxy_call(self, func, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 259, in proxy_call
    r = func.decompose(*args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_ops.py", line 307, in decompose
    return self._op_dk(dk, *args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 483, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 508, in inner_torch_dispatch
    out = proxy_call(self, func, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 393, in proxy_call
    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 206, in track_tensor_tree
    wrap_with_proxy(inner_res, proxy_res, constant)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 185, in wrap_with_proxy
    set_meta(proxy, e)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 149, in set_meta
    proxy.node.meta['val'] = torch.empty_strided(val.shape, val.stride(), device=val.device, dtype=val.dtype)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 878, in __torch_dispatch__
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 325, in constructors
    return FakeTensor(fake_mode, r, out_device)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 560, in __init__
    assert device.type != "meta"
AssertionError
```
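The traceback hints at how `jacfwd` evaluates `foo`. `wrapper_fn` calls `vmap(push_jvp, randomness=randomness)(basis)`, and `_jvp_with_argnums` then calls `func(*duals)`. As a result, `x + 1.0` runs on dual (primal, tangent) inputs inside a vmapped JVP, and that call is the one that reaches the failing `FakeTensor` assertion. Below is a minimal plain-Python sketch of that forward-mode strategy. It is not functorch's implementation: `Dual` and `jacfwd_sketch` are hypothetical names, scalars stand in for tensors, and an explicit loop over basis vectors stands in for `vmap`.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Dual:
    """Hypothetical forward-mode dual number: a primal value plus a tangent."""
    primal: float
    tangent: float

    def __add__(self, other):
        # d(x + y) = dx + dy; a plain Python scalar contributes no tangent.
        if isinstance(other, Dual):
            return Dual(self.primal + other.primal, self.tangent + other.tangent)
        return Dual(self.primal + other, self.tangent)

    __radd__ = __add__


def jacfwd_sketch(f: Callable[[List[Dual]], List[Dual]],
                  x: List[float]) -> List[List[float]]:
    """Hypothetical forward-mode Jacobian: one JVP per standard basis vector.

    functorch batches these JVPs with vmap; this sketch simply loops.
    Returns J with J[i][j] = d f_i / d x_j.
    """
    n = len(x)
    columns = []
    for j in range(n):
        basis = [1.0 if k == j else 0.0 for k in range(n)]
        # Pair each primal with its tangent (analogous to `duals` in the traceback).
        duals = [Dual(p, t) for p, t in zip(x, basis)]
        out = f(duals)
        # The output tangents form column j of the Jacobian.
        columns.append([o.tangent for o in out])
    # Transpose the collected columns into rows.
    return [[columns[j][i] for j in range(n)] for i in range(len(columns[0]))]


def foo(xs):
    # Element-wise analogue of the issue's foo: x + 1.0
    return [x + 1.0 for x in xs]


print(jacfwd_sketch(foo, [0.0]))       # → [[1.0]]
print(jacfwd_sketch(foo, [0.0, 2.0]))  # → [[1.0, 0.0], [0.0, 1.0]]
```

In eager mode (without `make_fx`), the real `functorch.jacfwd(foo)(x)` for the repro's 1×1 input should likewise contain a single entry of 1. Its shape is `(1, 1, 1, 1)`, which is the output shape followed by the input shape.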
zou3519 commented 1 year ago

Could be related to https://github.com/pytorch/pytorch/issues/90065