xidulu opened this issue 3 years ago
I agree that this is annoying, but I think the root problem is a bit subtler. The functionality that vmap provides is this: if f works fine on a tensor of a certain shape, then vmap(f) will work fine on that same shape plus an extra leading batch dimension.
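For example (a minimal sketch; the shapes are arbitrary):

```python
import torch
from functorch import vmap  # torch.func.vmap in newer releases

def f(x):
    # f expects a single vector of shape (3,)
    return x.sum()

x = torch.randn(3)
f(x)                    # works on shape (3,)

xs = torch.randn(8, 3)  # same shape plus a leading batch dimension
vmap(f)(xs)             # works, returns a tensor of shape (8,)
```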
If we applied this to F.conv2d, here is how it would look.
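A rough sketch (the shapes are chosen just for illustration):

```python
import torch
import torch.nn.functional as F
from functorch import vmap  # torch.func.vmap in newer releases

weight = torch.randn(16, 3, 5, 5)

# F.conv2d already expects a batched input of shape (N, C, H, W):
imgs = torch.randn(4, 3, 32, 32)
F.conv2d(imgs, weight)                            # -> (4, 16, 28, 28)

# So vmap adds one *more* leading dimension on top of that:
batch_of_batches = torch.randn(7, 4, 3, 32, 32)   # (B, N, C, H, W)
vmap(F.conv2d, in_dims=(0, None))(batch_of_batches, weight)
# -> (7, 4, 16, 28, 28)
```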
In other words, we can vmap fine over F.conv2d, as long as we consider the original function to already have a batch dimension :)
The best solution for this would probably be to allow F.conv2d to take in a rank 3 tensor that doesn't have a batch dimension.
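Something like this (hypothetical, assuming rank-3 support; not what F.conv2d accepts today):

```python
import torch
import torch.nn.functional as F
from functorch import vmap

weight = torch.randn(16, 3, 5, 5)
imgs = torch.randn(4, 3, 32, 32)    # (N, C, H, W)

# Hypothetical: if F.conv2d accepted an unbatched (C, H, W) input,
# then vmapping over the batch dimension would match the batched call:
out = vmap(lambda x: F.conv2d(x, weight))(imgs)   # each x would be (3, 32, 32)
# out would equal F.conv2d(imgs, weight), with shape (4, 16, 28, 28)
```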
EDIT: I noticed that Matt Johnson said very similar things here: https://github.com/google/jax/issues/381#issuecomment-464093876
Also, could you explain your use case? That would help us in figuring out how best to solve your problem.
@Chillee Thanks for your reply. I am currently developing a torch-based library in which we automatically add dimensions in front of user-defined variables. (To be more specific, we add a dimension that represents the index of independent Monte Carlo samples.)
We would like our library to be general and universal, i.e. users can define any "element-wise" function they want and our library automatically transforms their program into a batched version. This means we cannot rely on the broadcasting mechanism, since some operators (e.g. dot product) do not support broadcasting. Therefore, vmap is the best option.
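Roughly, what our library does is something like the following sketch (the names are just for illustration):

```python
import torch
from functorch import vmap

# A user-defined "element-wise" function, written with no batch dimension.
# torch.dot does not broadcast, so plain broadcasting cannot batch this.
def user_fn(x, y):
    return torch.dot(x, y)

n_mc_samples = 100
xs = torch.randn(n_mc_samples, 5)   # leading dim = Monte Carlo sample index
ys = torch.randn(n_mc_samples, 5)

batched_fn = vmap(user_fn)          # what our library adds automatically
out = batched_fn(xs, ys)            # shape (n_mc_samples,)
```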
The problem here is that most of torch's functions are well supported by vmap (torch.add, torch.dot, torch.matmul, F.linear, ...) and users can forget about the batch dimension. However, F.conv2d seems to be a bit inconsistent, so we need to handle F.conv2d with special hacks (e.g. turn off vmap and fall back to broadcasting/reshaping when the function is F.conv2d, roughly like the sketch below).
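For F.conv2d we currently fall back to something like this (a rough sketch; it assumes the weight is shared across Monte Carlo samples):

```python
import torch
import torch.nn.functional as F

def mc_conv2d(x, weight):
    # x: (S, N, C, H, W), where S is the Monte Carlo sample dimension.
    # Fold S into the existing batch dimension instead of using vmap.
    S, N = x.shape[:2]
    out = F.conv2d(x.reshape(S * N, *x.shape[2:]), weight)
    return out.reshape(S, N, *out.shape[1:])

x = torch.randn(10, 4, 3, 32, 32)
weight = torch.randn(16, 3, 5, 5)
mc_conv2d(x, weight).shape          # torch.Size([10, 4, 16, 28, 28])
```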
I think the best solution would be to allow F.conv2d to support rank-3 tensors. If that's not straightforward, it could be very helpful to leave some notes in the functorch.vmap documentation, since you clearly need to apply vmap to F.conv2d differently from other operators.
We are indeed planning to allow F.conv2d to support rank-3 tensors. This work is being tracked over at https://github.com/pytorch/pytorch/issues/60585; not only are we changing F.conv2d, we also need to change many of the other torch.nn functions that traditionally accept tensors with a single batch dimension.
If you're interested in contributing to get this feature in faster... feel free to express interest on https://github.com/pytorch/pytorch/issues/60585
@zou3519 Got it. I think I have the capacity to contribute to this a little bit. I will comment on issue #60585 later.
F.conv2d forces the input to be a 4-D tensor where the first dimension is a batch dimension. But when you vmap over the batch dimension of the input, the input seen inside the vmapped function becomes a 3-D tensor and torch raises an exception.
However, F.linear does not complain about this.
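A minimal sketch of what I mean (the shapes are arbitrary):

```python
import torch
import torch.nn.functional as F
from functorch import vmap

x = torch.randn(4, 3, 32, 32)       # (N, C, H, W)
w_lin = torch.randn(8, 32)
w_conv = torch.randn(16, 3, 5, 5)

# F.linear accepts any number of leading dimensions, so each (3, 32, 32)
# slice seen inside vmap is fine:
vmap(lambda xi: F.linear(xi, w_lin))(x)       # -> (4, 3, 32, 8)

# F.conv2d requires a 4-D (N, C, H, W) input, but inside vmap each slice
# is only (3, 32, 32), so this raises an exception:
vmap(lambda xi: F.conv2d(xi, w_conv))(x)
```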
This problem also occurs in jax: https://github.com/google/jax/issues/381