pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Use vmap with kwargs #70

Open xidulu opened 3 years ago

xidulu commented 3 years ago

Hi

I currently have the following use case:

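import torch
import torch.nn.functional as F
from functorch import vmap
x = torch.randn(2, 10)
w = torch.randn(2, 5, 10)
b = torch.randn(2, 5)
# bias is passed as a keyword argument, but in_dims lists three entries
print(vmap(F.linear, in_dims=(0, 0, 0))(x, w, bias=b).shape)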

However, this raises an exception:

ValueError: vmap(linear, in_dims=(0, 0, 0), ...)(<inputs>): in_dims is not compatible with the structure of `inputs`. in_dims has structure TreeSpec(tuple, None, [*, *, *]) but inputs has structure TreeSpec(tuple, None, [*, *]).

I also tried

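import torch
import torch.nn.functional as F
from functorch import vmap
x = torch.randn(2, 10)
w = torch.randn(2, 5, 10)
b = torch.randn(2, 5)
# in_dims for the keyword argument is given as a dict entry
print(vmap(F.linear, in_dims=(0, 0, {'bias': 0}))(x, w, bias=b).shape)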

which also does not work.

I am wondering what the correct way is to specify in_dims for keyword arguments passed via **kwargs. Or is it the case that vmap's in_dims only accepts positional arguments?

Thanks

zou3519 commented 3 years ago

vmap in_dims only accepts positional args; this follows the behavior of jax.vmap. For F.linear, is it possible to work around this by passing bias as a positional arg for now?

We probably do want some way for vmap to specify that it is mapping over kwargs, especially because PyTorch operators can have kwarg-only arguments. Your second approach, vmap(F.linear, in_dims=(0, 0, {'bias': 0}))(x, w, bias=b), seems pretty reasonable as an API.
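Until something like that exists, one way to handle a kwarg-only argument is to wrap the function so the argument can be given positionally. Below is a minimal sketch in plain Python; `as_positional` and `affine` are hypothetical names made up for illustration and are not part of functorch or PyTorch.

```python
import functools


def as_positional(func, *kwarg_names):
    """Wrap `func` so its trailing positional args are forwarded as keywords.

    Hypothetical helper for illustration; not part of functorch or PyTorch.
    """
    n = len(kwarg_names)

    @functools.wraps(func)
    def wrapper(*args):
        if len(args) < n:
            raise TypeError(
                f"expected at least {n} positional arguments, got {len(args)}"
            )
        split = len(args) - n
        # Leading args stay positional; trailing args are matched to
        # kwarg_names in order and passed by keyword.
        return func(*args[:split], **dict(zip(kwarg_names, args[split:])))

    return wrapper


# Stand-in for an operator with a keyword-only argument.
def affine(x, w, *, bias):
    return x * w + bias


affine_positional = as_positional(affine, "bias")
result = affine_positional(2, 3, 4)  # calls affine(2, 3, bias=4)
```

Since vmap would then only see positional inputs, something like vmap(as_positional(F.linear, "bias"), in_dims=(0, 0, 0))(x, w, b) should map over all three tensors, although that particular call is an untested assumption here.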

xidulu commented 3 years ago

@zou3519

Passing bias as a positional argument solves the problem:

print(vmap(F.linear, in_dims=(0, 0, 0))(x, w, b).shape)

And you are right... My usage does not seem to be supported by jax either.

BTW, I notice that, even if I am passing bias as positional arg, when I step into the function call, bias is still interpreted as a keyword argument? (I can see it inside **kwargs) Is that what you mean by PyTorch operators can have kwarg-only arguments?

If that is the case, I think a tweak to the vmap API is indeed necessary 🤔

zou3519 commented 3 years ago

> BTW, I notice that, even if I am passing bias as positional arg, when I step into the function call, bias is still interpreted as a keyword argument? (I can see it inside **kwargs)

Does this happen when you use __torch_function__? If so, that's because bias is passed as a keyword arg here: https://github.com/pytorch/pytorch/blob/d46689a2017cc046abdc938247048952df4f6de7/torch/nn/functional.py#L1846. I can look into why it is written this way; I always thought it was weird.
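A toy model of that effect, in plain Python: if a wrapper forwards an argument by keyword to an override handler, the handler sees it in kwargs no matter how the caller passed it. This is only a sketch of the mechanism, not PyTorch's actual code; `handler` and this `linear` are hypothetical stand-ins.

```python
def handler(func_name, args, kwargs):
    # Stand-in for a __torch_function__ override: report how arguments arrived.
    return {"func": func_name, "args": args, "kwargs": kwargs}


def linear(input, weight, bias=None):
    # The wrapper forwards `bias` by keyword, regardless of how the caller
    # passed it, so the handler always finds it in kwargs.
    return handler("linear", (input, weight), {"bias": bias})


seen = linear("x", "w", "b")  # the caller passes bias positionally
print(seen["args"], seen["kwargs"])
```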

> Is that what you mean by PyTorch operators can have kwarg-only arguments?

In Python it's possible to define a function with an argument that must be passed as a kwarg (and cannot be passed positionally). There are some examples here: https://python-3-for-scientists.readthedocs.io/en/latest/python3_advanced.html, and some PyTorch operators are written this way. F.linear isn't, though.
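For a quick illustration, parameters declared after a bare `*` in the signature are keyword-only. The function `scale` below is a made-up example, not a PyTorch operator.

```python
def scale(x, *, factor):
    # `factor` comes after the bare `*`, so it can only be passed by keyword.
    return x * factor


by_keyword = scale(3, factor=4)  # accepted

try:
    scale(3, 4)  # rejected: factor cannot be passed positionally
    positional_error = None
except TypeError as err:
    positional_error = err

print(by_keyword, type(positional_error).__name__)
```

With an operator declared like this, passing the argument positionally is not an option, which is why positional-only in_dims cannot cover it.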

xidulu commented 3 years ago

Yes! It happens when using __torch_function__. And that is indeed a little bit weird.