pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

vmap should be able to accept `None` in its out_dims argument #1082

Closed: zou3519 closed this issue 1 year ago

zou3519 commented 1 year ago

Currently, every output of vmap must be a Tensor. We can relax this constraint by letting the user specify None as an out_dim, meaning that the corresponding output is returned as-is instead of being stacked along a batch dimension.

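import torch
from functorch import vmap

def f(x):
    return 1  # a non-Tensor output

x = torch.randn(3)
# Proposed behavior: an out_dim of None returns the output as-is.
result = vmap(f, out_dims=(None,))(x)
assert result == 1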
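The following is a minimal sketch of the proposed semantics. It makes several assumptions:

- `loop_vmap` and `_combine` are hypothetical helpers. They are not part of functorch or PyTorch.
- vmap is modelled as a plain Python loop over the entries of a list.
- An out_dim of 0 collects the per-example outputs into a list.
- An out_dim of None returns the output as-is and rejects an output that differs across examples. This is similar in spirit to how JAX treats `out_axes=None`.
- A tuple of out_dims is accepted only when the function returns a tuple, so the single-output case uses `out_dims=None`.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

# Hypothetical: this toy model only understands out_dim 0 (stack) or None (as-is).
OutDims = Union[Optional[int], Tuple[Optional[int], ...]]


def _combine(values: List[Any], out_dim: Optional[int]) -> Any:
    """Merge one output's per-example values according to its out_dim."""
    if out_dim is None:
        first = values[0]
        # An output returned as-is must not depend on which example produced it.
        if any(v != first for v in values[1:]):
            raise ValueError("out_dim is None, but this output varies across the batch")
        return first
    if out_dim == 0:
        return list(values)  # "stack" along a new leading batch dimension
    raise ValueError("this sketch only models out_dim=0 or out_dim=None")


def loop_vmap(func: Callable[[Any], Any], out_dims: OutDims = 0) -> Callable[[Sequence[Any]], Any]:
    """Hypothetical loop-based stand-in for vmap over dim 0 of a Python list."""

    def wrapped(batch: Sequence[Any]) -> Any:
        if len(batch) == 0:
            raise ValueError("batch must be non-empty")
        results = [func(example) for example in batch]
        if isinstance(out_dims, tuple):
            # One out_dim per element of a tuple-returning func.
            if not all(isinstance(r, tuple) and len(r) == len(out_dims) for r in results):
                raise ValueError("func must return a tuple with one entry per out_dim")
            return tuple(
                _combine([r[i] for r in results], d) for i, d in enumerate(out_dims)
            )
        return _combine(results, out_dims)

    return wrapped


batch = [0.5, -1.0, 2.0]
print(loop_vmap(lambda x: 1, out_dims=None)(batch))            # 1
print(loop_vmap(lambda x: x * 2, out_dims=0)(batch))           # [1.0, -2.0, 4.0]
print(loop_vmap(lambda x: (x, 1), out_dims=(0, None))(batch))  # ([0.5, -1.0, 2.0], 1)
```

The per-example equality check is only a stand-in. A real vmap runs the function once on batched tensors, so it can tell directly whether an output carries a batch dimension.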

There are some interesting cases, for which we should look to JAX to decide what the semantics should be:

kshitij12345 commented 1 year ago

Fixed in https://github.com/pytorch/pytorch/pull/91644