keyu-tian / SparK

[ICLR'23 Spotlight🔥] The first successful BERT/MAE-style pretraining on any convolutional network; Pytorch impl. of "Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling"
https://arxiv.org/abs/2301.03580
MIT License
1.4k stars, 82 forks

SparK ResNet and global feature interaction #80

Open csvance opened 5 months ago

csvance commented 5 months ago

Hello, thanks for the great paper.

With the ResNet version of SparK, which uses sparse convolution and sparse batch normalization together, the flow and mixing of global semantic information is heavily restricted: the sparse operations effectively mask the receptive field, and batch norm provides no global channel interaction. It seems this information will struggle to propagate, especially in shallower networks with a smaller receptive field such as ResNet50. The paper showed empirically that ResNet50 benefited the least from SparK, failing to match the performance of supervised ResNet101. I was wondering whether the authors or anyone else have tried sparse group normalization with ResNet, so that there would be some global interaction of feature channels to better allow the learning of high-level features. Masked autoencoder pretraining has shown a lot of promise for data-limited tasks in medical imaging, and ResNet50 is commonly used by practitioners, so understanding how to use SparK pretraining most effectively has big implications for many in the field.

keyu-tian commented 5 months ago

@csvance very insightful thinking. I've also heard before that using a 3D sparse convolutional backbone network can lead to insufficient global information interaction in 3D point cloud perception (actually the interaction only occurs within connected components).

Yeah GroupNorm, LayerNorm or some attention-like operator can alleviate this problem. It's a promising direction to explore.
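
For anyone who wants to experiment with the idea discussed above, here is a minimal NumPy sketch of a masked ("sparse") group normalization. It is not taken from the SparK repository: the function name `sparse_group_norm`, the mask layout (1 = visible, 0 = masked), and the omission of learnable scale/shift parameters are all assumptions made for illustration. Statistics are computed per sample and per channel group over visible positions only, so masked positions neither contribute to the statistics nor receive non-zero output.

```python
import numpy as np

def sparse_group_norm(x, mask, num_groups, eps=1e-5):
    """Hypothetical masked GroupNorm.

    x:    (N, C, H, W) feature map
    mask: (N, 1, H, W), 1 = visible position, 0 = masked position
    """
    n, c, h, w = x.shape
    assert c % num_groups == 0, "channels must divide evenly into groups"
    ch_per_group = c // num_groups

    g = x.reshape(n, num_groups, ch_per_group, h, w)
    m = mask.reshape(n, 1, 1, h, w).astype(x.dtype)

    # Number of visible elements in each (sample, group); avoid division by zero.
    count = np.maximum(m.sum(axis=(2, 3, 4), keepdims=True) * ch_per_group, 1.0)

    # Mean and variance over visible positions only.
    mean = (g * m).sum(axis=(2, 3, 4), keepdims=True) / count
    var = (((g - mean) ** 2) * m).sum(axis=(2, 3, 4), keepdims=True) / count

    out = (g - mean) / np.sqrt(var + eps)
    # Re-apply the mask so masked positions stay exactly zero.
    return (out * m).reshape(n, c, h, w)

# Toy usage: the top half of each feature map is visible, the bottom half masked.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 4, 4))
mask = np.zeros((2, 1, 4, 4))
mask[:, :, :2, :] = 1.0
out = sparse_group_norm(x, mask, num_groups=4)
```

Because each group pools statistics over every visible location of a sample, this gives a (weak) form of interaction across the whole visible region, unlike per-channel batch statistics.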

csvance commented 5 months ago

Running some experiments with this on an internal dataset using the Big Transfer ResNetV2 architecture. Another reason I think GroupNorm might be promising is its transfer learning performance, as demonstrated in the Big Transfer paper. Even though sparse normalization counteracts some of the distribution shift, there is going to be a higher degree of feature interaction with unmasked input. Group norm could be more robust to this than batch norm when going from pretraining to downstream training. If it shows promise I will do an ImageNet run and post the results here.

csvance commented 4 months ago

Hello, I know the paper says you used a batch size of 4096, but I was curious how many GPUs that was split across? I'm having some stability issues and suspect they have to do with the effective batch size for batch norm in the decoder. Previously I was using a batch size of 64 and accumulating gradients 64 times on a single RTX 3090 24GB to get 4096. Now I have access to 4x A6000 48GB and am trying a batch size of 128 per GPU with gradient accumulation of 8 to get 4096, using sync batch norm as in the SparK decoder. I'm hoping that a much higher effective batch size for batch norm in the decoder will be the key to stopping training from diverging.

keyu-tian commented 4 months ago

@csvance Yeah, sync batch norm and a big enough batch size are important for BN stability. We used 32 Tesla A100s with bs=128 per GPU (so the total bs is 4096) most of the time, and didn't use gradient accumulation. I think bs=64 is too small for BN, and 4x128=512 should be better.
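
The distinction in the exchange above is between the batch that batch norm sees and the batch the optimizer sees. A small sketch of that arithmetic, using the three setups mentioned so far (the `Setup` class and its field names are hypothetical, introduced only for illustration): sync batch norm pools statistics across GPUs within one forward pass, while gradient accumulation only enlarges the optimizer batch.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Setup:
    name: str
    per_gpu_batch: int
    num_gpus: int
    accum_steps: int

    @property
    def sync_bn_batch(self) -> int:
        # Batch norm statistics are computed per forward pass, pooled across
        # GPUs when sync BN is used; accumulation steps do not contribute.
        return self.per_gpu_batch * self.num_gpus

    @property
    def optimizer_batch(self) -> int:
        # Gradient accumulation multiplies the batch seen per optimizer step.
        return self.sync_bn_batch * self.accum_steps

setups = [
    Setup("1x RTX 3090", per_gpu_batch=64, num_gpus=1, accum_steps=64),
    Setup("4x A6000", per_gpu_batch=128, num_gpus=4, accum_steps=8),
    Setup("32x A100", per_gpu_batch=128, num_gpus=32, accum_steps=1),
]

for s in setups:
    print(f"{s.name}: BN batch = {s.sync_bn_batch}, optimizer batch = {s.optimizer_batch}")
```

All three setups reach an optimizer batch of 4096, but the batch norm statistics are computed over 64, 512, and 4096 samples respectively.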

csvance commented 4 months ago

Yeah, I'm definitely seeing a big difference between my new and old setups. There is still some instability with a 4*128 effective batch size for sync batch norm, but things converge much better than I have seen before. It looks like BatchNorm with a large batch size is crucial for the decoder here: I have tried the decoder with GroupNorm, and convergence is significantly worse without any improvement in stability.

Just as an experiment, I'm running with an image size of 128x128 and a batch size of 512 per GPU, giving me a sync BN batch size of 2048 (accumulating gradients twice to get 4096 per optimizer step). It will be interesting to see if there are still issues with constant gradient explosion. Here is what the divergence looks like in the loss curve; it pretty much always happens when I reach a certain loss, around 0.3 MSE or so. It doesn't matter how I tune gradient clipping, learning rate, etc.; it's like the loss landscape is extremely sharp / unstable without a sufficient batch size for batch norm.

[image: loss curve showing the divergence]

csvance commented 4 months ago

I was able to get SparK to converge with LayerNorm in the decoder instead of BatchNorm! I had forgotten to enable decoupled weight decay in the optimizer I was using, which was the source of the divergences (too much weight decay relative to the learning rate). During training there are still times when the loss spikes a bit, but it's not extreme, and the loss starts decreasing again toward a better minimum.

I have no doubt that BatchNorm will still converge faster, but using LayerNorm in the decoder could be a good option for those who do not have access to a huge number of GPUs.
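
To illustrate why the weight decay setting mentioned above can matter so much, here is a scalar sketch of one Adam step with coupled (L2-style) versus decoupled (AdamW-style) weight decay. The thread does not say which optimizer or hyperparameters were used, so the function `adam_step` and every value below (`lr`, `wd`, the betas) are assumptions chosen only to make the effect visible. With a zero loss gradient, the coupled variant feeds `wd * w` through Adam's adaptive rescaling and moves the weight by roughly `lr`, while the decoupled variant shrinks it by only `lr * wd`.

```python
import math

def new_state():
    # Adam's running moments and step counter for a single scalar weight.
    return {"m": 0.0, "v": 0.0, "t": 0}

def adam_step(w, grad, state, lr=1e-3, wd=0.05, decoupled=True,
              beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on a scalar weight (illustrative, hypothetical values)."""
    if decoupled:
        # AdamW-style: shrink the weight directly, scaled only by lr * wd.
        w = w * (1.0 - lr * wd)
    else:
        # L2-style: fold the decay into the gradient, where it is rescaled
        # by Adam's adaptive denominator like any other gradient term.
        grad = grad + wd * w

    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad
    m_hat = state["m"] / (1 - beta1 ** state["t"])
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)

# Zero loss gradient isolates the effect of the decay term itself.
w_coupled = adam_step(1.0, 0.0, new_state(), decoupled=False)
w_decoupled = adam_step(1.0, 0.0, new_state(), decoupled=True)
print("shrinkage, coupled:  ", 1.0 - w_coupled)
print("shrinkage, decoupled:", 1.0 - w_decoupled)
```

In this toy setting the coupled update shrinks the weight about `1 / wd` times more per step than the decoupled one, which is one way a decay value tuned for decoupled decay can end up far too strong relative to the learning rate.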

keyu-tian commented 4 months ago

@csvance Happy to hear that, and thanks for your effort! Substituting BN with LN or GN (GroupNorm) is indeed worth trying, and I guess BN isn't always essential. We initially adopted BN just because UNet used it, but I believe LN or GN could effectively replace BN without much of a performance drop, and yes, this could be particularly beneficial for those with limited GPU resources.

csvance commented 4 months ago

For using SparK with backscatter X-ray images, I found it helped to use a larger epsilon for tile normalization and to also normalize the x_hat tiles. The reason is that many tiles are mostly background, since a lot of X-rays are taller than they are wide and often have large segments of noisy background. This made the learned representation transfer better to downstream problems. Without the large epsilon, training is unstable at the start when normalizing x_hat tiles, which seems to negatively impact the learned representation. I suspect normalizing x_hat is a useful inductive bias, but I haven't tried any of this with ImageNet yet.
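
A rough NumPy sketch of the idea described above, assuming MAE-style per-tile standardization of the reconstruction target. The helper names `normalize_tiles` and `tile_mse` and the two epsilon values are hypothetical (no concrete values are given here), and the restriction of the loss to masked tiles is omitted for brevity. The point it shows: on a nearly flat, noisy background tile, a tiny epsilon blows the noise up to unit scale, while a larger epsilon keeps the normalized tile close to zero.

```python
import numpy as np

def normalize_tiles(tiles, eps):
    """Standardize each tile independently. tiles: (num_tiles, pixels_per_tile)."""
    mean = tiles.mean(axis=-1, keepdims=True)
    var = tiles.var(axis=-1, keepdims=True)
    return (tiles - mean) / np.sqrt(var + eps)

def tile_mse(x_hat, x, eps, normalize_pred=True):
    """MSE between (optionally normalized) predicted tiles and normalized targets."""
    target = normalize_tiles(x, eps)
    pred = normalize_tiles(x_hat, eps) if normalize_pred else x_hat
    return float(((pred - target) ** 2).mean())

# A mostly-background tile: constant intensity plus a small amount of noise.
rng = np.random.default_rng(0)
background = 0.05 + 1e-3 * rng.normal(size=(1, 256))

for eps in (1e-6, 1e-2):  # small vs. larger epsilon (illustrative values)
    z = normalize_tiles(background, eps)
    print(f"eps={eps:g}: max |normalized value| = {np.abs(z).max():.4f}")
```

Normalizing `x_hat` with the same function means an unstable or near-constant early prediction is also divided by a small standard deviation, which is one plausible reason the larger epsilon helps at the start of training.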

Until now I have been working with a relatively small subset of my dataset, roughly 100k images. I'm going to ramp things up by several orders of magnitude now. Results on downstream tasks are very promising even with so few images: downstream performance is already close to ImageNet21K transfer performance.