pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

vmap should accept a dim_size=None argument #1081

Open: zou3519 opened this issue 1 year ago


vmap should accept a dim_size=None argument that lets the user specify the size of the dimension being vmapped over. It should behave similarly to the axis_size argument of JAX's vmap.

The net effect is that one should be able to vmap over functions that take no Tensors as input at all!
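To make the proposed semantics concrete, here is a minimal pure-Python model of vmap over Python lists. `toy_vmap` is a hypothetical helper, not a functorch API: the batch size is inferred from the mapped arguments as vmap does today, falls back to dim_size when nothing is mapped, and is rejected when neither is available. Treating a dim_size that disagrees with the mapped inputs as an error is a design choice of this sketch, not something the issue specifies.

```python
from typing import Any, Callable, List, Optional, Sequence


def toy_vmap(
    f: Callable[..., Any],
    in_dims: Optional[Sequence[Optional[int]]] = None,
    dim_size: Optional[int] = None,
) -> Callable[..., List[Any]]:
    """Hypothetical list-based model of vmap with a dim_size argument.

    Arguments whose in_dims entry is not None are lists mapped over
    element-wise; arguments whose entry is None are passed unchanged
    to every call.
    """

    def wrapped(*args: Any) -> List[Any]:
        dims = list(in_dims) if in_dims is not None else [0] * len(args)
        if len(dims) != len(args):
            raise ValueError("in_dims must have one entry per argument")
        # Collect the batch sizes implied by the mapped arguments.
        sizes = {len(a) for a, d in zip(args, dims) if d is not None}
        # An explicit dim_size also constrains the batch size.
        if dim_size is not None:
            sizes.add(dim_size)
        if not sizes:
            raise ValueError("no mapped inputs: dim_size must be given")
        if len(sizes) > 1:
            raise ValueError(f"inconsistent batch sizes: {sorted(sizes)}")
        (size,) = sizes
        # Call f once per batch index, slicing only the mapped arguments.
        return [
            f(*(a[i] if d is not None else a for a, d in zip(args, dims)))
            for i in range(size)
        ]

    return wrapped


# A function with no inputs can now be "batched" via dim_size alone.
ones = toy_vmap(lambda: 1.0, dim_size=5)()
```

In this model, `toy_vmap(lambda x, y: x + y, in_dims=(0, None))([1, 2, 3], 10)` maps over the first argument only, while a zero-argument function works only because dim_size supplies the size.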

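import torch
from functorch import vmap

# Note: dim_size is the argument proposed in this issue; vmap does not accept it yet.
def f():
    return torch.tensor(1.)

result = vmap(f, dim_size=5)()
assert torch.allclose(result, torch.tensor([1., 1., 1., 1., 1.]))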
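Until such an argument exists, one workaround is to thread a throwaway tensor of the desired size through vmap and ignore it inside the function. The sketch below assumes PyTorch 2.x, where functorch's vmap is available as `torch.func.vmap`, and relies on vmap broadcasting outputs that do not depend on any batched input across the batch dimension. `vmap_with_dim_size` is a hypothetical helper, not part of any library.

```python
import torch
from torch.func import vmap


def vmap_with_dim_size(func, dim_size):
    """Hypothetical helper: emulate the proposed dim_size with a dummy input."""

    def wrapped(*args):
        # A throwaway tensor whose only job is to fix the batch size.
        dummy = torch.empty(dim_size)
        # Map over the dummy only; every real argument is passed unbatched.
        in_dims = (0,) + (None,) * len(args)
        return vmap(lambda _unused, *a: func(*a), in_dims=in_dims)(dummy, *args)

    return wrapped


def f():
    return torch.tensor(1.)


result = vmap_with_dim_size(f, dim_size=5)()
print(result)
```

Because the output is broadcast rather than computed per batch element, it may come back as an expanded view; calling `.clone()` before writing into it in place is a safe precaution.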

We should also investigate whether JAX's axis_size argument provides anything beyond this that vmap should support.