locuslab / deq

[NeurIPS'19] Deep Equilibrium Models
MIT License

Does MDEQ have different inference results for different batch sizes? #21

Closed: undercutspiky closed this issue 2 years ago

undercutspiky commented 2 years ago

I'm running some experiments with MDEQ on the ImageNet validation set, and I get different activations for the DEQ layer (the variable new_z1 in mdeq_core.py) depending on the batch size. I can see that the broyden function has no loop over the batch, but since I'm not familiar with Broyden's method and its implementation, I don't know whether different images within a batch can interfere with each other, either directly or indirectly (for instance, by affecting the number of solver iterations). Should I run inference on one image at a time?

jerrybai1995 commented 2 years ago

Hi @undercutspiky ,

Thanks for the question! Theoretically there shouldn't be any difference due to batch size, as all operations are conducted in parallel across the samples in a batch. E.g., the code uses einsum calls that keep the batch dimension intact (see https://github.com/locuslab/deq/blob/master/lib/solvers.py#L118).
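To make the batch-independence argument concrete, here is a minimal pure-Python sketch, not the repository's code. It assumes a hypothetical scalar layer `f(z, x) = 0.9 * tanh(z) + x` standing in for MDEQ's f, and applies it row by row for a fixed number of iterations, mimicking what a batched `einsum` does: no update ever mixes rows. The names `f` and `fixed_point_fixed_steps` are illustrative only.

```python
import math

def f(z, x):
    # Hypothetical DEQ layer: a per-sample contraction (not MDEQ's actual f).
    return 0.9 * math.tanh(z) + x

def fixed_point_fixed_steps(xs, n_steps=30):
    """Run the same fixed number of fixed-point iterations for every sample."""
    zs = [0.0] * len(xs)
    for _ in range(n_steps):
        # Each sample is updated from its own z and x only, like an einsum
        # that keeps the batch dimension separate.
        zs = [f(z, x) for z, x in zip(zs, xs)]
    return zs

alone = fixed_point_fixed_steps([2.0])
batched = fixed_point_fixed_steps([2.0, 0.1, -1.5, 0.7])
print(alone[0] == batched[0])  # True: same iteration count, same result
```

With a fixed iteration count, batch composition cannot change a sample's result; any batch-size effect has to come from elsewhere, such as a batch-wide stopping rule or a normalization layer that pools statistics across samples.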

However, other factors could affect this. For example, your layer f might contain batch normalization (which I intentionally excluded, using only LayerNorm instead). You may also see different activations if dropout is active (i.e., the model is in training mode), since producing a random mask is exactly what dropout does.
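The reason batch normalization can couple samples while LayerNorm cannot is where the statistics are computed. Below is a simplified stdlib sketch without learnable affine parameters; `batch_norm` mimics training-mode BatchNorm (statistics over the batch) and `layer_norm` normalizes each sample over its own features. Both helper names are hypothetical. Note that PyTorch's BatchNorm layers use running statistics in eval mode, so this coupling mainly shows up in training mode.

```python
import math

def batch_norm(batch, eps=1e-5):
    """Normalize each feature with the mean/variance taken over the batch."""
    n_feat = len(batch[0])
    out = [[0.0] * n_feat for _ in batch]
    for j in range(n_feat):
        col = [row[j] for row in batch]
        mean = sum(col) / len(col)
        var = sum((v - mean) ** 2 for v in col) / len(col)
        for i, v in enumerate(col):
            out[i][j] = (v - mean) / math.sqrt(var + eps)
    return out

def layer_norm(batch, eps=1e-5):
    """Normalize each sample with the mean/variance over its own features."""
    out = []
    for row in batch:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        out.append([(v - mean) / math.sqrt(var + eps) for v in row])
    return out

x0 = [1.0, 2.0, 4.0]
small = [x0, [0.0, 1.0, 0.0]]
large = [x0, [0.0, 1.0, 0.0], [5.0, -3.0, 2.0]]

print(layer_norm(small)[0] == layer_norm(large)[0])  # True
print(batch_norm(small)[0] == batch_norm(large)[0])  # False
```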

If you are doing neither of the above, then I'd be curious to know 1) whether Broyden's method converged properly (you should be able to tell this via abs_trace and rel_trace), and 2) how large the difference (i.e., ||new_z1_bsz1 - new_z1_bsz2||) is between the new_z1 values you obtained with different batch sizes.
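Both diagnostics are straightforward to compute. The sketch below assumes the absolute residual is ||f(z) - z|| taken over the whole batch and the relative residual divides that by ||f(z)||; check lib/solvers.py for the exact definitions behind abs_trace and rel_trace. Activations are represented as hypothetical nested lists of flattened per-sample values, and `residuals` / `overlap_diff` are illustrative helper names, not functions from the repository.

```python
import math

def l2(vec):
    return math.sqrt(sum(v * v for v in vec))

def residuals(z, fz, eps=1e-9):
    """Batch-level abs/rel residuals plus per-sample abs residuals.

    z, fz: lists of per-sample flattened activations with matching shapes.
    Assumes abs = ||f(z) - z|| over the whole batch and rel = abs / ||f(z)||.
    """
    diffs = [[a - b for a, b in zip(fs, zs)] for fs, zs in zip(fz, z)]
    flat_diff = [d for row in diffs for d in row]
    flat_fz = [v for row in fz for v in row]
    abs_res = l2(flat_diff)
    rel_res = abs_res / (l2(flat_fz) + eps)
    per_sample = [l2(row) for row in diffs]  # shows which sample dominates
    return abs_res, rel_res, per_sample

def overlap_diff(z_big, z_small):
    """Norm of z_big[:len(z_small)] - z_small, e.g. bsz-4 vs bsz-2 outputs."""
    flat = [a - b for rb, rs in zip(z_big, z_small) for a, b in zip(rb, rs)]
    return l2(flat)

abs_res, rel_res, per_sample = residuals(z=[[3.0, 0.0], [0.0, 0.0]],
                                         fz=[[3.0, 4.0], [0.0, 0.0]])
print(abs_res, round(rel_res, 3), per_sample)  # 4.0 0.8 [4.0, 0.0]
print(overlap_diff([[1.0, 1.0], [2.0, 2.0], [9.0, 9.0]],
                   [[1.0, 1.0], [2.0, 5.0]]))  # 3.0
```

Because the batch-level norm pools every sample, one poorly converged image can keep the whole batch's residual high; the per-sample residuals make that visible.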

undercutspiky commented 2 years ago

Hi @jerrybai1995 ,

Thanks for replying.

I'm just running inference on your pre-trained model so I did not modify any of the normalisation layers.

I think the batch size could still have an effect because the absolute and the relative differences are calculated for the whole batch and used to break out of the function, right?
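To see how a single batch-wide residual couples samples, here is a minimal pure-Python sketch. It uses plain fixed-point iteration rather than Broyden's method, and a hypothetical scalar layer `f` in place of MDEQ; `solve_batch`, `tol`, and the inputs are illustrative assumptions. The solver stops once the residual over the whole batch drops below `tol`, so a slowly converging sample forces extra iterations on every other sample.

```python
import math

def f(z, x):
    # Hypothetical contraction layer: converges fast for x = 2.0, slowly for x = 0.1.
    return 0.9 * math.tanh(z) + x

def solve_batch(xs, tol=1e-4, max_iter=100):
    """Fixed-point iteration that stops when the batch-level residual < tol."""
    zs = [0.0] * len(xs)
    for k in range(1, max_iter + 1):
        fzs = [f(z, x) for z, x in zip(zs, xs)]
        # One residual for the whole batch, pooled across samples.
        res = math.sqrt(sum((a - b) ** 2 for a, b in zip(fzs, zs)))
        zs = fzs
        if res < tol:
            return zs, k
    return zs, max_iter

z_alone, k_alone = solve_batch([2.0])
z_batch, k_batch = solve_batch([2.0, 0.1])
print(k_alone, k_batch)               # the mixed batch needs more iterations
print(abs(z_alone[0] - z_batch[0]))   # small but nonzero
```

The same input ends up with slightly different outputs purely because of its batch-mates; the gap is bounded by how close to convergence the earlier stop already was.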

Here are the numbers you requested:

With nz2 = new_z1 for batch size 4 and nz1 = new_z1 for batch size 2, the value of torch.norm(nz2[:2,:,:] - nz1) is 19.165.


EDIT: I checked the abs and rel lists in trace_dict. For batch size 2, the lowest point was first reached at iteration 25, so the values at that point should have been recorded as the final values. For batch size 4, however, the lowest point was first reached at iteration 26, so it did one extra iteration. I guess the difference is small, but it comes from this extra iteration, right?
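The "lowest point" bookkeeping can be sketched as follows, assuming (as described in this comment) that the solver returns the iterate with the lowest batch-level residual within a fixed iteration budget. This toy uses plain fixed-point iteration with a hypothetical scalar layer, not the repository's Broyden implementation; the `trace` dict only mimics trace_dict's abs/rel lists, and `solve_track_lowest` and `threshold` are illustrative names. A strict `<` keeps the first iteration at which the minimum is reached, and which iteration that is depends on every sample in the batch.

```python
import math

def f(z, x):
    # Hypothetical contraction layer (not MDEQ's f).
    return 0.9 * math.tanh(z) + x

def solve_track_lowest(xs, threshold=30, eps=1e-9):
    """Run `threshold` iterations; return the iterate with the lowest batch residual."""
    zs = [0.0] * len(xs)
    trace = {"abs": [], "rel": []}
    best_res, best_step, best_zs = float("inf"), 0, zs
    for k in range(1, threshold + 1):
        fzs = [f(z, x) for z, x in zip(zs, xs)]
        abs_res = math.sqrt(sum((a - b) ** 2 for a, b in zip(fzs, zs)))
        rel_res = abs_res / (math.sqrt(sum(a * a for a in fzs)) + eps)
        trace["abs"].append(abs_res)
        trace["rel"].append(rel_res)
        # Strict '<' records the FIRST iteration that reaches the minimum.
        if abs_res < best_res:
            best_res, best_step, best_zs = abs_res, k, fzs
        zs = fzs
    return best_zs, best_step, trace

_, step_alone, trace_alone = solve_track_lowest([2.0])
_, step_batch, trace_batch = solve_track_lowest([2.0, 0.1])
print(step_alone, step_batch)  # the mixed batch reaches its lowest point later
```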

jerrybai1995 commented 2 years ago

Hi @undercutspiky ,

Thanks for providing these numbers!

Yes, you are definitely right that batch size affects both abs and rel, so it does make a difference. Generally, I (personally) consider a relative residual < 0.1 to be "ok convergence", so the rel values of 0.18 and 0.28 you reported seem a bit high.

On the other hand, the iteration counts you reported (25 and 26) seem reasonable. This paper (see Table 3 in Appendix C) reproduced MDEQ-Small on Cityscapes, and you can see that the performance they obtained quickly saturated beyond 20 iterations. Generally, as we increase the iteration count further, I would expect the difference between batch size 2 and batch size 4 to become very small. But you are right that in this particular case, since there is no batchnorm, the difference could be due to this extra iteration.

One interesting thing to try here would be to set the iteration threshold to 50 steps (or more). As long as the layer converges, batch size should not make a big difference.

undercutspiky commented 2 years ago

Hi @jerrybai1995 ,

Thank you very much for verifying and for all the information. That all clears it up then.

Before I close the issue, I have another, unrelated question. Loading the MDEQ-Small model for ImageNet throws an error (I've been using MDEQ-XL). Do I need to change something in the yaml file to make it work? Should I create a separate issue for this?

jerrybai1995 commented 2 years ago

Can you check if this issue solves the problem?

undercutspiky commented 2 years ago

Yes, that solved the issue. Thanks a lot!