The only difference that I see between the implementations is the LayerNorm, which is implemented differently. The official implementation uses a channels_first LayerNorm where, as far as I can tell, the weights and biases are not learned, whereas the torch implementation by default uses the channels_last nn.LayerNorm with learned weight and bias (elementwise_affine=True).
I don't know if this makes the difference, but randomly initialized torch and official implementations also perform differently, so the pre-trained ImageNet weights are not the problem.
@ds2268 There is no elementwise_affine in the official implementation, but the weights and biases are still hardcoded to be learnable (both are nn.Parameters). The reason for using a channels_first LN is that it avoids a permute/reshape operation; the computed results should be the same.
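For anyone following along, here is a quick sanity check (my own sketch, not code from either repo) that a channels_first LayerNorm with learnable weight/bias computes the same result as F.layer_norm applied in channels_last layout:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 96, 14, 14)           # (N, C, H, W)
w, b = torch.randn(96), torch.randn(96)  # stand-ins for the learnable weight/bias

# channels_last: permute to (N, H, W, C), normalize over C, permute back
y_last = F.layer_norm(x.permute(0, 2, 3, 1), (96,), w, b, 1e-6).permute(0, 3, 1, 2)

# channels_first: normalize over dim 1 directly, no permute needed
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)  # biased variance, matching F.layer_norm
y_first = w[:, None, None] * ((x - u) / torch.sqrt(s + 1e-6)) + b[:, None, None]

print(torch.allclose(y_last, y_first, atol=1e-5))  # True
```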
Have you used sparse conv during training? I mean, it should only be used in the SSL phase.
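(For anyone unfamiliar, a purely conceptual sketch of what "sparse" means here, not SparK's actual API: during masked pretraining the convolution must not leak information between visible and masked patches, which is why it only belongs in the SSL phase and should be swapped for dense convs downstream.)

```python
import torch
import torch.nn as nn

class MaskedConv2d(nn.Conv2d):
    """Conceptual stand-in for a sparse conv in masked-image SSL (hypothetical,
    not SparK's implementation). Assumes stride=1 / same padding so the mask
    and the output have matching spatial size."""
    def forward(self, x, mask):  # mask: (N, 1, H, W), 1 = visible, 0 = masked
        out = super().forward(x * mask)  # zero masked inputs before convolving
        return out * mask                # re-mask so masked positions stay zero

conv = MaskedConv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(2, 3, 32, 32)
mask = (torch.rand(2, 1, 32, 32) > 0.6).float()
y = conv(x, mask)  # masked positions of y are exactly zero
```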
The torch implementation uses stochastic depth (drop_path in the official implementation):
https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py#L380
The official implementation defaults drop_path_rate to 0 if you don't specify it. I think this is what caused the performance degradation.
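For reference, stochastic depth randomly drops whole residual branches per sample during training; a minimal sketch along the lines of the usual timm-style drop_path (paraphrased, not verbatim):

```python
import torch

def drop_path(x: torch.Tensor, drop_prob: float = 0., training: bool = False) -> torch.Tensor:
    """Zero the residual branch for a random subset of samples, scaling the
    survivors by 1 / keep_prob so the expected value is unchanged."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # per-sample mask, broadcast over C/H/W
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    return x * mask / keep_prob
```

So when training the official implementation from scratch, it looks like you need to pass the rate explicitly (e.g. drop_path_rate=0.1, which, if I read the torchvision source right, is the stochastic_depth_prob default for convnext_tiny) rather than relying on the default of 0.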
That explains it.
I am using a ConvNeXt backbone in some downstream tasks. Originally, I used torch's default implementation.
I then used SparK to SSL-pretrain the ConvNeXt backbone, hoping to improve downstream performance on medical images.
I replaced the torch ConvNeXt implementation with yours (the original Meta implementation + sparse convolutions) to enable the use of SparK-pretrained ConvNeXt models (the weights are not compatible). I noticed 3-10% worse performance on the downstream task when initializing your ConvNeXt implementation with ImageNet weights (not SparK SSL-pretrained) compared with torch's ImageNet-initialized ConvNeXt.
@keyu-tian do you have any idea why your ConvNeXt implementation performs much worse out of the box?