import torch
import functorch
# Input-tensor configuration for the repro below: single-precision floats on the CPU.
dtype, device = torch.float32, torch.device('cpu')
def foo(x):
    """Return ``x + 1.0``.

    Minimal function used by the repro below. It is wrapped in
    ``functorch.vmap``, ``functorch.jacrev`` and ``functorch.jacfwd`` and then
    traced with ``functorch.make_fx``. It works for any operand that supports
    ``+`` with a Python float, e.g. a float or a ``torch.Tensor``.
    """
    # The body must be indented under the ``def``. In the pasted version it
    # sat at column 0, which is an IndentationError.
    return x + 1.0
# Single-element 2-D input tensor, built with the configured dtype and device.
x = torch.tensor([[0.0]], dtype=dtype, device=device)
# Trace three functorch transforms of `foo` with make_fx. Per the inline
# annotations, the vmap and jacrev traces succeed. The jacfwd trace fails:
# jacfwd vmaps a jvp over `foo`, and while make_fx decomposes the `+` inside
# `foo`, building a FakeTensor hits `assert device.type != "meta"` and raises
# AssertionError (see the traceback that follows the script).
functorch.make_fx(functorch.vmap(foo))(x) # Works
functorch.make_fx(functorch.jacrev(foo))(x) # Works
functorch.make_fx(functorch.jacfwd(foo))(x) # Fails
Error Message:
```
Traceback (most recent call last):
File "/home/kshiteej/Pytorch/pytorch_functorch/test/test_scratch.py", line 31, in <module>
functorch.make_fx(functorch.jacfwd(foo))(x) # Fails
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 683, in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 441, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py", line 739, in trace
(self.create_arg(fn(*args)),),
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 457, in wrapped
out = f(*tensors)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 996, in wrapper_fn
results = vmap(push_jvp, randomness=randomness)(basis)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 362, in wrapped
return _flat_vmap(
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 35, in fn
return f(*args, **kwargs)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 489, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 989, in push_jvp
output = _jvp_with_argnums(func, args, basis, argnums=argnums, has_aux=has_aux)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 35, in fn
return f(*args, **kwargs)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 837, in _jvp_with_argnums
result_duals = func(*duals)
File "/home/kshiteej/Pytorch/pytorch_functorch/test/test_scratch.py", line 26, in foo
return x + 1.0
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 483, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 508, in inner_torch_dispatch
out = proxy_call(self, func, args, kwargs)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 259, in proxy_call
r = func.decompose(*args, **kwargs)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_ops.py", line 307, in decompose
return self._op_dk(dk, *args, **kwargs)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 483, in __torch_dispatch__
return self.inner_torch_dispatch(func, types, args, kwargs)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 508, in inner_torch_dispatch
out = proxy_call(self, func, args, kwargs)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 393, in proxy_call
track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 206, in track_tensor_tree
wrap_with_proxy(inner_res, proxy_res, constant)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 185, in wrap_with_proxy
set_meta(proxy, e)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 149, in set_meta
proxy.node.meta['val'] = torch.empty_strided(val.shape, val.stride(), device=val.device, dtype=val.dtype)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 878, in __torch_dispatch__
op_impl_out = op_impl(self, func, *args, **kwargs)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 325, in constructors
return FakeTensor(fake_mode, r, out_device)
File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 560, in __init__
assert device.type != "meta"
AssertionError
```
Error Message: