pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Allow batch norm with all variations of batching when training=False #958

Closed samdow closed 2 years ago

samdow commented 2 years ago

Closes #867

This updates the batch norm check so that if we aren't training, we don't error, since no in-place update will occur.

The testing on this got messy. Now we can test every combination of batching on the input, running_mean, and running_var (previously we only did a subset).

Basically, I knew that the only bool argument in an instance norm or batch norm call must be training, because of the signature, so I looked for that in the kwargs. However, in the composition tests (like vmapvjp) the kwargs are captured by the vmap-ed over function, so they are no longer inputs. This meant refactoring get_fallback_and_vmap_exhaustive to take a flag indicating whether the function being tested uses batch norm with training=True, instead of taking the opinfo and doing a name match on it. The helper function that figures this out takes the name of the op, grabbed from the opinfo, and the kwargs used by the current sample input.
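
For illustration, here is a hedged sketch of the kind of check described above; the name and exact structure are illustrative, not the actual functorch test helper:

```python
# Illustrative sketch only; not the actual functorch test code.
def is_batch_norm_training(op_name, sample_kwargs):
    """Return True if this sample is a batch/instance norm call running with training=True."""
    if "batch_norm" not in op_name and "instance_norm" not in op_name:
        return False
    # The only bool argument in the batch_norm / instance_norm signatures plays the
    # role of `training`, so any bool found in the kwargs must be it.
    bool_args = [v for v in sample_kwargs.values() if isinstance(v, bool)]
    return any(bool_args)
```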

samdow commented 2 years ago

This was merged, but I'm still concerned about batch norm's ability to support non-contiguous running means/vars. I will investigate.

JackShi9 commented 1 year ago

@samdow Hi! I ran into the same problem as in #867. I'm trying to use a pre-trained network to extract features and need to compute the sample-wise Jacobian (i.e. vmap(jacrev)), and I got the following error: ValueError: expected 4D input (got 3D input). I'm sure the model is in .eval() mode, so is there any pointer you can give me to make this work? Thanks!

samdow commented 1 year ago

Hi @JackShi9! What's happening in your code is that, because we're vmapping over the input, it's being treated as a 3D tensor instead of the original 4D tensor. The basic idea of vmap is that it "hides" the dimension being vmapped over, meaning that the rank of the tensor is one less than that of the underlying tensor. So if the input works without vmap, we'll have to adjust it to work with vmap.

A couple of options that might help:

(1) Unsqueeze the input before you run the code (see the sketch below). Assuming it works to run the input not under vmap, the input is of size [B, ...]. It should also work to make the input of size [B, 1, ...], because once vmapped, the input will be seen as [1, ...], which has the right number of dimensions. Note that this means your output will have an extra dimension in it.

(2) Use layer normalization. Sadly, most of the norm layers also require a batch dimension, but I think LayerNorm will work if you can swap it in.
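
For option (1), a minimal sketch of the unsqueeze trick, assuming a toy convolutional feature extractor with BatchNorm2d in eval mode (the model and shapes here are made up for illustration):

```python
import torch
import torch.nn as nn
from functorch import vmap, jacrev

# Hypothetical feature extractor with a batch norm layer, in eval mode.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.Flatten(),
).eval()

x = torch.randn(5, 3, 8, 8)   # [B, C, H, W] works fine without vmap

# Under vmap the batch dimension is hidden, so each sample looks 3D to BatchNorm2d.
# Adding a singleton dimension keeps the per-sample input 4D: [1, C, H, W].
x = x.unsqueeze(1)            # [B, 1, C, H, W]

per_sample_jac = vmap(jacrev(model))(x)
print(per_sample_jac.shape)   # note the extra singleton dimensions from the unsqueeze
```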

Let me know if there are more questions! I'm also happy to look at a code snippet if it helps.

JackShi9 commented 1 year ago

Thanks for the quick response @samdow, the unsqueezing does the trick!

JackShi9 commented 1 year ago

Another quick question, @samdow: when vmap is calculating the Jacobian of the output of some function (i.e. vmap(jacrev(some_function))(input_to_the_function)), is it storing the gradient information of the same function as many times as the batch size of the input? I assume not, since that would be very computationally expensive (and unnecessary?) and the memory would run out for any decent-sized network, even at a moderate/small batch size. Is my suspicion right?

samdow commented 1 year ago

When vmap is calculating the Jacobian of the output of some function (i.e. vmap(jacrev(some_function))(input_to_the_function)), is it storing the gradient information of the same function as many times as the batch size of the input? I assume not, since that would be very computationally expensive (and unnecessary?) and the memory would run out for any decent-sized network, even at a moderate/small batch size. Is my suspicion right?

Your suspicion is right: typically we don't do that. vmap does "autobatching," meaning it replaces what would be multiple kernel calls in a for loop with a single, batched kernel call. This takes advantage of the fact that most PyTorch operations can deal with arbitrary dimensions. So this should only store the gradient information once.
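
Here's a small, hedged illustration of what that means in practice (the function and shapes below are made up): the vmapped call computes the same per-sample Jacobians as an explicit Python loop, but each underlying operation is dispatched once as a batched kernel rather than once per sample.

```python
import torch
from functorch import vmap, jacrev

weight = torch.randn(4, 3)

def f(x):
    # A tiny stand-in for "some_function": maps a [4] vector to a [3] vector.
    return torch.tanh(x @ weight)

batch = torch.randn(8, 4)

# Naive approach: one jacrev call (and one set of kernel launches) per sample.
loop_jacs = torch.stack([jacrev(f)(x) for x in batch])

# vmap: the same math expressed as a single batched computation.
vmap_jacs = vmap(jacrev(f))(batch)

assert torch.allclose(loop_jacs, vmap_jacs)   # shapes: [8, 3, 4]
```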

However, if you ever see a warning that looks like: There is a performance drop because we have not yet implemented the batching rule for <some operator>, it means we will be running a for loop over the batch, which, as you point out, is inefficient. If you do run into this, please file an issue so we can prioritize it!

JackShi9 commented 1 year ago

Got it. Thanks so much @samdow, this is really helpful!

JackShi9 commented 1 year ago

Hi @samdow! I ran into another problem with vmap and functorch. I am currently trying out different parallelization schemes, i.e. DataParallel and DistributedDataParallel. vmap and functorch work fine when I use model = nn.DataParallel(model), but when I use DistributedDataParallel, the error mentioned in #867 comes up again (specifically this), despite the fact that the model is in .eval() mode. As a quick fix, I used replace_all_batch_norm_modules_ to patch the batch norms, but then received this error:

RuntimeError: functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this

I am not sure what this error implies/means, but my hunch (which might very well be wrong) is that the problem is coming from model = nn.SyncBatchNorm.convert_sync_batchnorm(model), which is required for DistributedDataParallel. Do you have any suggested quick fix, or am I making mistakes somewhere?