Theivaprakasham closed this issue 2 years ago
Both are issues on the timm side; feel free to open an issue at the timm repo. Here is an explanation of the mentioned issues:
The main problem is that the CatAvgMax pooler concatenates average-pooled and max-pooled features, producing a final feature of dimension 2 x D (where D is the number of channels before pooling). However, the LayerNorm layer that follows expects D=1024, not D=2048:
Sequential(
(global_pool): SelectAdaptivePool2d (pool_type=catavgmax, flatten=Identity())
(norm): LayerNorm2d((1024,), eps=1e-06, elementwise_affine=True)
(flatten): Flatten(start_dim=1, end_dim=-1)
(drop): Dropout(p=0.0, inplace=False)
(fc): Identity()
)
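The mismatch can be reproduced with plain PyTorch; this is a minimal sketch where a random 1024-channel feature map stands in for the ConvNeXt base output:

```python
import torch
import torch.nn.functional as F

# Feature map before pooling: D=1024 channels (as in convnext_base)
x = torch.randn(1, 1024, 7, 7)

# CatAvgMax concatenates average- and max-pooled features along the channel dim
avg = F.adaptive_avg_pool2d(x, 1)
mx = F.adaptive_max_pool2d(x, 1)
cat = torch.cat([avg, mx], dim=1)  # shape (1, 2048, 1, 1)

# A LayerNorm built for 1024 channels cannot normalize this 2048-channel feature
print(cat.shape)
```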
There are two workarounds you can apply:
1) Use the original pooling that the ConvNeXt model was trained with:
encoder = create_encoder("convnext_base_in22ft1k", n_in=3, pretrained=False, pool_type=None)
model = create_simclr_model(encoder, hidden_size=2048, projection_size=128, nlayers=2)
2) Replace the LayerNorm(1024) layer with a LayerNorm(2048), duplicating the pretrained affine parameters so both halves of the concatenated feature are normalized with the original weights and biases:
import torch
from timm.models.layers import LayerNorm2d

encoder = create_encoder("convnext_base_in22ft1k", n_in=3, pretrained=False, pool_type=PoolingType.CatAvgMax)
norm = LayerNorm2d(2048)
# Repeat the learned scale and shift for both halves of the catavgmax feature
norm.weight.data = torch.cat([encoder.head.norm.weight.data, encoder.head.norm.weight.data])
norm.bias.data = torch.cat([encoder.head.norm.bias.data, encoder.head.norm.bias.data])
encoder.head.norm = norm
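The parameter expansion can be sanity-checked with plain nn.LayerNorm modules standing in for the encoder head's norm layer (hypothetical stand-ins, not the timm classes):

```python
import torch
from torch import nn

old = nn.LayerNorm(1024)  # stands in for the pretrained encoder.head.norm
new = nn.LayerNorm(2048)  # wide enough for the catavgmax feature

# Duplicate the learned scale and shift for both halves of the feature
new.weight.data = torch.cat([old.weight.data, old.weight.data])
new.bias.data = torch.cat([old.bias.data, old.bias.data])

x = torch.randn(2, 2048)
out = new(x)  # normalizes the doubled feature without a shape error
assert out.shape == (2, 2048)
```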
As for the Swin transformer, only average pooling is supported and it is hardcoded in the forward method, so there is no need to pass a pool_type:
>>> encoder.forward_features??
def forward_features(self, x):
    x = self.patch_embed(x)
    if self.absolute_pos_embed is not None:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)
    x = self.layers(x)
    x = self.norm(x)  # B L C
    x = self.avgpool(x.transpose(1, 2))  # B C 1
    x = torch.flatten(x, 1)
    return x
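The hardcoded pooling step can be illustrated in isolation; the shapes here are assumptions (batch 2, 49 tokens, 768 channels, roughly matching a small Swin variant):

```python
import torch
from torch import nn

x = torch.randn(2, 49, 768)          # B L C: token sequence after the final norm
avgpool = nn.AdaptiveAvgPool1d(1)
pooled = avgpool(x.transpose(1, 2))  # B C 1: average over the token dimension
out = torch.flatten(pooled, 1)       # B C: the final pooled feature
assert out.shape == (2, 768)
```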
Versions: fastai==2.5.3, fastcore==1.3.29, kornia==0.6.3, timm==0.5.4
Describe the bug The tutorials for SimCLR v1 & v2, MoCo v1 & v2, BYOL, SwAV, Barlow Twins, and DINO don't support Vision Transformer-based or ConvNeXt-based timm models as encoders.
Attached below is the sample stack trace of the same.
To Reproduce Steps to reproduce the behavior:
1. Error with full stack trace
2. Error with full stack trace