Open xidulu opened 3 years ago
vmap's in_dims only accepts positional args; this follows the behavior of jax.vmap. For F.linear, is it possible to work around this by passing bias as a positional arg for now?
We probably do want some way for vmap to specify that it is mapping over kwargs, especially because PyTorch operators can have kwarg-only arguments. Your second approach, print(vmap(F.linear, in_dims=(0, 0, {'bias': 0}))(x, w, bias=b).shape), seems pretty reasonable as an API.
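Until something like that exists, the positional workaround suggested above could be generalized with a small wrapper. The following is only a sketch in plain Python: kwargs_as_positional and scale_and_shift are hypothetical names made up for illustration, not part of PyTorch or functorch, and the stand-in function is used so the snippet runs without torch.

```python
import functools


def kwargs_as_positional(fn, names):
    """Wrap `fn` so the keyword arguments listed in `names` are accepted
    as trailing positional arguments and forwarded to `fn` by keyword.

    Hypothetical helper for illustration; not a PyTorch/functorch API.
    """
    names = tuple(names)

    @functools.wraps(fn)
    def wrapper(*args):
        if len(args) < len(names):
            raise TypeError(
                f"expected at least {len(names)} positional arguments"
            )
        split = len(args) - len(names)
        # The last len(names) positional values become keyword arguments.
        return fn(*args[:split], **dict(zip(names, args[split:])))

    return wrapper


# Stand-in for an operator that takes a keyword argument (hypothetical).
def scale_and_shift(x, w, *, bias=0):
    return x * w + bias


positional = kwargs_as_positional(scale_and_shift, ["bias"])
result = positional(2, 3, 4)  # same as scale_and_shift(2, 3, bias=4)
```

Assuming vmap only needs a callable that takes positional inputs, the wrapped function could then presumably be used as vmap(kwargs_as_positional(F.linear, ["bias"]), in_dims=(0, 0, 0))(x, w, b); that combination has not been verified here.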
@zou3519
Using bias as positional arguments can solve the problem
print(vmap(F.linear, in_dims=(0, 0, 0))(x, w, b).shape)
And you are right... My usage seems not to be supported by jax either.
BTW, I noticed that even if I pass bias as a positional arg, when I step into the function call, bias is still interpreted as a keyword argument (I can see it inside **kwargs).
Is that what you mean by "PyTorch operators can have kwarg-only arguments"?
If that is the case, I think a tweak to the vmap API is indeed necessary 🤔
BTW, I notice that, even if I am passing bias as positional arg, when I step into the function call, bias is still interpreted as a keyword argument? (I can see it inside **kwargs)
Does this happen when you use __torch_function__? If so, that's because it's being treated as a keyword arg here: https://github.com/pytorch/pytorch/blob/d46689a2017cc046abdc938247048952df4f6de7/torch/nn/functional.py#L1846. I can look into why it is this way; I always thought it was weird.
Is that what you mean by PyTorch operators can have kwarg-only arguments?
In Python it's possible to define a function with an argument that must be passed as a kwarg (and cannot be passed as positional). There are some examples here: https://python-3-for-scientists.readthedocs.io/en/latest/python3_advanced.html and some PyTorch operators are written this way. F.linear isn't, though
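As a quick illustration of the keyword-only syntax described above, here is a minimal sketch in plain Python. linear_like is a hypothetical function made up for this example; it is not F.linear, which accepts bias positionally.

```python
# Parameters after the bare `*` can only be passed by keyword.
def linear_like(x, w, *, bias=None):
    out = x * w
    return out if bias is None else out + bias


ok = linear_like(2, 3, bias=1)  # keyword form is accepted

try:
    linear_like(2, 3, 1)  # positional form is rejected by Python
    raised = False
except TypeError:
    raised = True
```

A function written this way gives an in_dims tuple nothing to refer to for bias, since in_dims entries correspond to positional inputs.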
Yes! It happens when using __torch_function__. And that is indeed a little bit weird.
Hi
I am currently having the following use case:
However, this raises an exception:
I also tried
which also does not work.
I am wondering what the correct way is to specify in_dims for keyword arguments in **kwargs. Or is it the case that vmap's in_dims only accepts positional arguments? Thanks