pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

nn.functional.conv2d is not very compatible with vmap #71

Open xidulu opened 3 years ago

xidulu commented 3 years ago

F.conv2d requires its input to be a 4-D tensor whose first dimension is a batch dimension. But when you use vmap over the batch dimension of the input, each call to F.conv2d sees a 3-D tensor, and torch raises an exception.

F.linear, however, does not complain about this.
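For contrast, here is F.linear under vmap. This is a minimal sketch, assuming PyTorch 2.x (where functorch's vmap is exposed as torch.func.vmap; older releases used functorch.vmap). The sizes and tensor names are illustrative, not from the thread. F.linear accepts any input of shape (*, in_features), so a 1-D per-sample vector is fine:

```python
import torch
import torch.nn.functional as F
from torch.func import vmap  # assumption: PyTorch 2.x; previously functorch.vmap

in_features, out_features, S = 4, 2, 7
weight = torch.randn(out_features, in_features)
samples = torch.randn(S, in_features)  # S independent samples

# Inside vmap each call sees a 1-D [in_features] vector, which F.linear accepts
out = vmap(lambda v: F.linear(v, weight))(samples)
print(out.shape)  # torch.Size([7, 2])
```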

This problem also occurs in jax: https://github.com/google/jax/issues/381

Chillee commented 3 years ago

I agree that this is annoying, but I think the root problem is a bit subtler. The functionality that vmap provides is this:

  1. Let's say we have a function f that takes in a tensor of shape [X,Y,Z] and outputs a tensor of shape [X].
  2. Now, vmap(f) will take in a tensor of shape [B, X, Y, Z], and output a tensor of shape [B, X].

In other words, if f works fine with a certain shape of tensor, then vmap(f) will work fine with that same shape + a batch dimension.
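The shape contract above can be checked directly. This is a minimal sketch, assuming PyTorch 2.x (vmap imported from torch.func); the function f and the sizes X, Y, Z, B are illustrative choices, not from the thread. It also shows that vmap(f) behaves like applying f to every slice along the new leading dimension and stacking the results:

```python
import torch
from torch.func import vmap  # assumption: PyTorch 2.x; previously functorch.vmap

X, Y, Z, B = 3, 4, 5, 8

def f(t):
    # f maps a tensor of shape [X, Y, Z] to a tensor of shape [X];
    # inside vmap, dims 1 and 2 refer to the per-example tensor
    return t.sum(dim=(1, 2))

single = torch.randn(X, Y, Z)
batch = torch.randn(B, X, Y, Z)

print(f(single).shape)       # torch.Size([3])
print(vmap(f)(batch).shape)  # torch.Size([8, 3])

# vmap(f) matches a Python loop over the batch dimension followed by stack
looped = torch.stack([f(batch[i]) for i in range(B)])
```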

If we apply this to F.conv2d, this is how it should look:

  1. We have a function F.conv2d that takes in a tensor of shape [N,C_in,H,W] and outputs a tensor of shape [N,C_out,H_out,W_out].
  2. Now, vmap(F.conv2d) will take in a tensor of shape [B, N, C_in, H, W] and output a tensor of shape [B, N, C_out, H_out, W_out].

In other words, we can vmap fine over F.conv2d, as long as we consider the original function to already have a batch dimension :)
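A minimal sketch of that point, assuming PyTorch 2.x (vmap from torch.func). The sizes are illustrative. The extra dimension B sits on top of conv2d's own batch dimension N, and the weight is shared across B, so it is passed with in_dims=None:

```python
import torch
import torch.nn.functional as F
from torch.func import vmap  # assumption: PyTorch 2.x; previously functorch.vmap

B, N, C_in, H, W = 2, 4, 3, 8, 8
C_out, k = 5, 3

x = torch.randn(B, N, C_in, H, W)        # extra leading dim B on top of conv2d's batch dim N
weight = torch.randn(C_out, C_in, k, k)  # shared across B, hence in_dims=None

# Map over dim 0 of x only; every slice x[b] is a regular 4-D conv2d input
batched_conv = vmap(F.conv2d, in_dims=(0, None))
out = batched_conv(x, weight)
print(out.shape)  # torch.Size([2, 4, 5, 6, 6])

# Reference: run conv2d on each [N, C_in, H, W] slice separately and stack
looped = torch.stack([F.conv2d(x[b], weight) for b in range(B)])
```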

The best solution for this would probably be to allow F.conv2d to take in a rank 3 tensor that doesn't have a batch dimension.

EDIT: I notice that Matt Johnson said very similar things here: https://github.com/google/jax/issues/381#issuecomment-464093876

Chillee commented 3 years ago

Also, could you explain your use case? That would help us in figuring out how best to solve your problem.

xidulu commented 3 years ago

@Chillee Thanks for your reply. I am currently developing a torch-based library in which we automatically add some dimensions in front of user-defined variables (more specifically, we add a dimension that indexes independent Monte Carlo samples).

We would like our library to be general: users can define any "element-wise" function they want, and our library automatically transforms their program into a batched version. This means we cannot rely on broadcasting, since some operators do not support it (e.g. the dot product). Therefore, vmap is the best option.
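The dot-product case illustrates why broadcasting is not enough. This is a minimal sketch, assuming PyTorch 2.x (vmap from torch.func); the sample count S and vector length D are illustrative. torch.dot only accepts 1-D tensors, so it cannot be applied to a stack of samples directly, while vmap(torch.dot) computes one dot product per sample:

```python
import torch
from torch.func import vmap  # assumption: PyTorch 2.x; previously functorch.vmap

S, D = 4, 3
a = torch.randn(S, D)  # S Monte Carlo samples of a length-D vector
b = torch.randn(S, D)

# torch.dot(a, b) would raise a RuntimeError here because a and b are 2-D
per_sample = vmap(torch.dot)(a, b)
print(per_sample.shape)  # torch.Size([4])

# Hand-written batched equivalent for comparison
manual = (a * b).sum(dim=-1)
```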

The problem is that most of torch's functions are well supported by vmap (torch.add, torch.dot, torch.matmul, F.linear, ...), so users can forget about the batch dimension. F.conv2d, however, is inconsistent, so we need to handle it with special hacks (e.g. turning off vmap and switching back to broadcasting when the function is F.conv2d).
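A less invasive hack than switching off vmap is to wrap F.conv2d so that each per-sample call adds a size-1 batch dimension and removes it afterwards. This is a minimal sketch, assuming PyTorch 2.x (vmap from torch.func); `conv2d_single` is a hypothetical helper name, not part of any library. Because the inner call always sees a 4-D tensor, the wrapper does not depend on F.conv2d accepting rank-3 input:

```python
import torch
import torch.nn.functional as F
from torch.func import vmap  # assumption: PyTorch 2.x; previously functorch.vmap


def conv2d_single(x, weight, bias=None):
    """Hypothetical helper: apply conv2d to one [C_in, H, W] example.

    Adds a size-1 batch dim so F.conv2d always receives 4-D input,
    then removes that dim from the [1, C_out, H_out, W_out] result.
    """
    return F.conv2d(x.unsqueeze(0), weight, bias).squeeze(0)


S, C_in, H, W = 6, 3, 10, 10
x = torch.randn(S, C_in, H, W)  # S samples, each a single [C_in, H, W] image
weight = torch.randn(8, C_in, 3, 3)

# vmap over the sample dimension; the weight is shared (in_dims=None)
out = vmap(conv2d_single, in_dims=(0, None))(x, weight)
print(out.shape)  # torch.Size([6, 8, 8, 8])
```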

And I think the best solution would be to allow F.conv2d to support rank-3 tensors. If that's not straightforward, I think it would be very helpful to leave a note in the documentation of functorch.vmap, since you obviously need to apply vmap to F.conv2d differently from other operators.

zou3519 commented 3 years ago

We are indeed planning to allow F.conv2d to support rank-3 tensors. This work is being tracked at https://github.com/pytorch/pytorch/issues/60585; not only are we changing F.conv2d, we also need to change many of the other torch.nn functions that traditionally require input with a leading batch dimension.
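With rank-3 support in place, vmap composes with F.conv2d without special casing. This is a minimal sketch, assuming a PyTorch release in which F.conv2d accepts unbatched [C_in, H, W] input (recent releases do; the sketch also assumes PyTorch 2.x for torch.func.vmap). The sizes are illustrative:

```python
import torch
import torch.nn.functional as F
from torch.func import vmap  # assumption: PyTorch 2.x; previously functorch.vmap

C_in, H, W = 3, 8, 8
weight = torch.randn(4, C_in, 3, 3)

# Assumes F.conv2d accepts an unbatched [C_in, H, W] tensor
single = torch.randn(C_in, H, W)
print(F.conv2d(single, weight).shape)  # torch.Size([4, 6, 6])

# Each call inside vmap sees a rank-3 tensor, so no unsqueeze/squeeze is needed
samples = torch.randn(5, C_in, H, W)
out = vmap(lambda img: F.conv2d(img, weight))(samples)
print(out.shape)  # torch.Size([5, 4, 6, 6])
```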

If you're interested in contributing to get this feature in faster... feel free to express interest on https://github.com/pytorch/pytorch/issues/60585

xidulu commented 3 years ago

@zou3519 Got it. I think I have the capacity to contribute a little to this. I will comment on issue #60585 later.