Currently, all the outputs of vmap must be Tensors. We can relax this constraint by letting the user specify None for an out_dim (which means, return the object as-is).
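import torch
from torch import vmap

def f(x):
    return 1

x = torch.randn(3)
# Proposed behavior: an out_dim of None returns the output as-is.
result = vmap(f, out_dims=(None,))(x)
assert result == 1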
There are some interesting cases (we should look to JAX for what the semantics should be):
What happens if we specify out_dims=None for a Tensor output, but that Tensor is being vmapped over?
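To make the question concrete, here is a minimal loop-based sketch of one possible answer: treat None as valid only when the output does not vary across the batch, and raise otherwise. This is an assumption for discussion, not something settled here. naive_vmap and _same are hypothetical helpers, not part of torch, and comparing values per example is only a rough stand-in for "is this Tensor being vmapped over".

```python
import torch

def _same(a, b):
    # Hypothetical helper: compare two per-example outputs by value.
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return torch.equal(a, b)
    return a == b

def naive_vmap(func, out_dim=0):
    # Hypothetical stand-in for vmap: maps func over dim 0 of one input with a Python loop.
    def wrapped(x):
        outputs = [func(row) for row in x.unbind(0)]
        if out_dim is None:
            # Assumed semantics: None means "return as-is", which only makes
            # sense if every per-example output is identical.
            first = outputs[0]
            for other in outputs[1:]:
                if not _same(first, other):
                    raise ValueError("out_dim=None, but the output varies across the batch")
            return first
        # Otherwise stack the per-example outputs along out_dim.
        return torch.stack(outputs, dim=out_dim)
    return wrapped

x = torch.arange(3.0)
constant = naive_vmap(lambda row: 1, out_dim=None)(x)  # non-Tensor output, returned as-is
doubled = naive_vmap(lambda row: row * 2)(x)           # regular batched output
```

Under this sketch, naive_vmap(lambda row: row, out_dim=None)(x) raises a ValueError, because the output depends on the mapped input. Silently returning one slice or the full batched Tensor would be the alternatives to weigh against that.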