Closed samdow closed 2 years ago
This was merged, but I'm still concerned about batch norm's ability to support non-contiguous running means/vars. I will investigate.
@samdow Hi! I ran into the same problem as in #867. I'm trying to use a pre-trained network to extract features and need to compute the sample-wise Jacobian (i.e. vmap(jacrev)) and got the following error:
ValueError: expected 4D input (got 3D input)
I'm sure the model is in .eval() mode, so is there any pointer you can give me to make this work? Thanks!
Hi @JackShi9! What's happening in your code is that, because we're vmapping over the input, it's being treated as a 3D tensor instead of the original 4D tensor. The basic idea of vmap is that it "hides" the dimension being vmapped over, so the rank of the tensor appears to be one less than that of the underlying tensor. So if the input works without vmap, we'll have to adjust it to work with vmap.
A couple options that might help:
(1) Unsqueeze the input before you run the code. Assuming the input works when not under vmap and has size [B, ...], it should also work with size [B, 1, ...]: once vmapped, the input will be seen as [1, ...], which has the right number of dimensions. Note that this means your output will have an extra dimension in it.
(2) Use layer normalization. Sadly, most of the norm layers also require a batch dimension. However, I think LayerNorm will work if you can swap it in
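To make option (1) concrete, here's a minimal sketch. The toy conv + batch norm model is an assumption standing in for the pre-trained network, and the import fallback assumes either the torch.func or the standalone functorch API is available:

```python
import torch

try:
    from torch.func import vmap, jacrev  # PyTorch >= 2.0
except ImportError:
    from functorch import vmap, jacrev

# Toy stand-in for the pre-trained feature extractor (an assumption,
# not the model from this thread)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, 3, padding=1),
    torch.nn.BatchNorm2d(4),
    torch.nn.Flatten(),
)
model.eval()

x = torch.randn(8, 3, 5, 5)  # [B, C, H, W]

# Without the extra dim, each vmapped slice is [C, H, W] (3D) and
# BatchNorm2d raises "expected 4D input (got 3D input)".
# Adding a singleton dim makes each slice [1, C, H, W] (4D):
jac = vmap(jacrev(model))(x.unsqueeze(1))
print(jac.shape)  # per-sample Jacobian, with extra singleton dims
```

Note the extra dimensions in the result: each per-sample Jacobian carries the singleton batch dim on both the output and input sides, which you can squeeze out afterwards.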
Let me know if there are more questions! Also happy to look at a code snippet if it helps.
Thanks for the quick response @samdow, the unsqueezing does the trick!
Another quick question @samdow: when vmap is calculating the Jacobian of the output with respect to some function (i.e. vmap(jacrev(some_function))(input_to_the_function)), is it storing the gradient information of the same function as many times as the batch size of the input? I assume not, as this would be very computationally expensive (and unnecessary?) and the memory would run out for any decent-sized network at even a moderate batch size. Is my suspicion right?
Your suspicion is right: typically we don't do that. vmap does "autobatching", meaning that it replaces what would be multiple kernel calls in a for loop with a single, batched kernel call. This takes advantage of the fact that most PyTorch operations can deal with arbitrary dimensions. So this should only store the gradient information once.
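The autobatching idea can be sketched like this (a minimal illustration of the semantics, not functorch internals; the import fallback is an assumption about the installed version):

```python
import torch

try:
    from torch.func import vmap  # PyTorch >= 2.0
except ImportError:
    from functorch import vmap

def f(x):
    # per-sample computation on a single 3-vector
    return x @ torch.ones(3)

xs = torch.randn(5, 3)

# Conceptually, vmap(f)(xs) computes this loop...
looped = torch.stack([f(x) for x in xs])

# ...but it is lowered to a single batched kernel call, so the
# function (and its gradient info) isn't duplicated per sample
batched = vmap(f)(xs)
assert torch.allclose(looped, batched)
```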
However, if you ever see a warning that looks like:
There is a performance drop because we have not yet implemented the batching rule for <some batching rule>
This means that we will be running a for loop over the batch dimension, which as you point out is inefficient. If you do run into this, please file an issue so we can prioritize it!
Got it. Thanks so much @samdow , this is really helpful!
Hi @samdow! I ran into another problem with vmap and functorch. I am currently trying out different parallelization schemes, i.e. DataParallel and DistributedDataParallel. vmap and functorch work fine when I use model = nn.DataParallel(model), but when I use DistributedDataParallel the error mentioned in #867 comes up again (specifically this), despite the fact that the model is in .eval() mode. As a quick fix, I used replace_all_batch_norm_modules_ to patch the batch norms, but then received this error:
RuntimeError: functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this
I am not sure what this error means, but my hunch (which might very well be wrong) is that the problem is coming from model = nn.SyncBatchNorm.convert_sync_batchnorm(model), which is required for DistributedDataParallel. Do you have any suggested quick fix, or am I making a mistake somewhere?
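For reference, a minimal sketch of what the replace_all_batch_norm_modules_ patch does on a toy model (the model and the functorch.experimental import path are assumptions; this does not reproduce the SyncBatchNorm setup):

```python
import torch
from functorch.experimental import replace_all_batch_norm_modules_

# Toy model standing in for the real network (an assumption)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, 3, padding=1),
    torch.nn.BatchNorm2d(4),
)

# In-place patch: per the functorch docs, this sets running_mean and
# running_var to None and track_running_stats to False on every
# batch norm module, so no in-place buffer update happens under vmap
replace_all_batch_norm_modules_(model)

bn = model[1]
print(bn.track_running_stats, bn.running_mean)
```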
Closes #867
This updates the batch norm check so that, if we aren't training, we don't error, since no inplace update will occur.
The testing on this got messy. Now we can test every combination of batching on the input, running_mean, and running_var (previously we only tested a subset).
Basically, I knew that the only bool argument in an instance norm or batch norm call must be training, because of the signature, so I looked for that in the kwargs. However, in the composition tests (like vmapvjp) the kwargs are captured by the vmap-ed over function and are no longer inputs. So this needed a refactor of get_fallback_and_vmap_exhaustive to take a flag for whether it's using a function that uses batch norm and is training, instead of taking the opinfo and doing a name match on it. The helper function that figures this out takes the name of the op, grabbed from the opinfo, and the kwargs used by the current sample input.
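A hypothetical sketch of such a helper (the name, op strings, and logic are illustrative, not the actual functorch test code):

```python
def is_batch_norm_training(op_name, kwargs):
    """Illustrative helper: given an op name (from the opinfo) and the
    sample input's kwargs, report whether a norm op is in training mode.

    Mirrors the reasoning above: in the batch_norm/instance_norm calls,
    the only bool-typed argument is the training flag, so any bool found
    in the kwargs tells us whether the op is training.
    """
    if op_name not in ("nn.functional.batch_norm",
                       "nn.functional.instance_norm"):
        return False
    return any(v for v in kwargs.values() if isinstance(v, bool))
```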