keyu-tian / SparK

[ICLR'23 Spotlight🔥] The first successful BERT/MAE-style pretraining on any convolutional network; Pytorch impl. of "Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling"
https://arxiv.org/abs/2301.03580

ConvNext implementation performance #81

Closed: ds2268 closed this issue 6 months ago

ds2268 commented 6 months ago

I am using a ConvNeXt backbone in some downstream tasks. Originally, I used torchvision's default implementation.

I then used SparK to pretrain the ConvNeXt backbone with SSL, hoping to improve downstream performance on medical images.

I replaced the torchvision ConvNeXt implementation with yours (the original Meta implementation plus sparse convolutions) so I could use the SparK-pretrained ConvNeXt models (the weights are not compatible). I noticed 3-10% worse performance on the downstream task when I initialized your ConvNeXt implementation with ImageNet weights (not SparK SSL-pretrained) compared with torchvision's ImageNet-initialized ConvNeXt.

@keyu-tian do you have any idea why your ConvNeXt implementation performs much worse out of the box?

ds2268 commented 6 months ago

The only difference I see between the implementations is the LayerNorm, which is implemented differently:

The torch implementation by default uses the channels_last nn.LayerNorm with learned weight and bias (elementwise_affine=True), whereas the official implementation uses a channels_first LayerNorm where the weights and biases are not learned.

I don't know if this is what makes the difference, but randomly initialized torch and official implementations also perform differently, so the pretrained ImageNet weights are not the problem.
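
To make this concrete, one way to check the claim is to count the LayerNorm affine parameters in each model. The sketch below is a minimal, hypothetical helper (`layernorm_affine_params` is not from either codebase) run against the torchvision model; the same check can be run on the SparK/Meta ConvNeXt to see whether its channels_first LayerNorm carries learnable weight and bias.

```python
import torch
import torch.nn as nn
from torchvision.models import convnext_tiny

def layernorm_affine_params(model: nn.Module) -> int:
    """Count affine (weight/bias) elements of all LayerNorm-like modules.

    Covers torch's nn.LayerNorm (and subclasses such as torchvision's
    LayerNorm2d) as well as custom modules simply named "LayerNorm" that
    expose .weight / .bias tensors, as in the Meta ConvNeXt code.
    """
    total = 0
    for name, m in model.named_modules():
        if isinstance(m, nn.LayerNorm) or m.__class__.__name__ == "LayerNorm":
            for p in (getattr(m, "weight", None), getattr(m, "bias", None)):
                if isinstance(p, torch.Tensor):
                    total += p.numel()
                    # a parameter is learnable iff requires_grad is True
                    # print(name, tuple(p.shape), p.requires_grad)
    return total

tv_model = convnext_tiny(weights=None)
print("torchvision LN affine params:", layernorm_affine_params(tv_model))
```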

keyu-tian commented 6 months ago

@ds2268 There is no elementwise_affine flag in the official implementation, but the weights and biases are still hardcoded to be learnable (both are nn.Parameters). The reason for using a channels_first LN is that it avoids a reshape/permute operation, but the computed results should be the same.
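
As a quick numerical check of this, one can compare a channels_first LayerNorm written the way the Meta ConvNeXt code does it (manual mean/variance over the channel dimension) against nn.LayerNorm/F.layer_norm applied after permuting to channels_last. This is a minimal sketch, not the exact SparK code:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C, eps = 96, 1e-6
x = torch.randn(2, C, 14, 14)          # (N, C, H, W)
weight, bias = torch.randn(C), torch.randn(C)

# channels_first LN (as in the Meta ConvNeXt code): normalize over dim 1 directly
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
y_first = (x - u) / torch.sqrt(s + eps)
y_first = weight[:, None, None] * y_first + bias[:, None, None]

# channels_last LN: permute, apply F.layer_norm over the last dim, permute back
y_last = F.layer_norm(x.permute(0, 2, 3, 1), (C,), weight, bias, eps)
y_last = y_last.permute(0, 3, 1, 2)

print(torch.allclose(y_first, y_last, atol=1e-5))  # expected: True
```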

Have you used sparse conv during downstream training? It should only be used in the SSL pretraining phase.

ds2268 commented 6 months ago

The torch implementation uses stochastic depth (called drop_path in the official implementation):

https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py#L380

The official implementation defaults drop_path to 0 if you don't specify it. I think this is what caused the performance degradation.
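
For reference, a minimal sketch of how one might align the two. It assumes the Meta-style ConvNeXt constructor (as in the official models/convnext.py; the import path will differ depending on how SparK packages it) and that torchvision's convnext_tiny uses a stochastic-depth probability of 0.1 by default; verify both against your installed versions.

```python
# Assumed import path; adjust to wherever the Meta/SparK-style ConvNeXt lives.
from convnext import ConvNeXt

model = ConvNeXt(
    in_chans=3,
    num_classes=1000,
    depths=[3, 3, 9, 3],        # ConvNeXt-Tiny configuration
    dims=[96, 192, 384, 768],
    drop_path_rate=0.1,         # torchvision's convnext_tiny default is 0.1;
)                               # the official ConvNeXt defaults this to 0.
```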

keyu-tian commented 6 months ago

That explains it.