KeremTurgutlu / self_supervised

Implementation of popular SOTA self-supervised learning algorithms as Fastai Callbacks.
Apache License 2.0

Transformer and ConvNext Timm models not supported as Encoder #78

Closed Theivaprakasham closed 2 years ago

Theivaprakasham commented 2 years ago

Version: fastai==2.5.3, fastcore==1.3.29, kornia==0.6.3, timm==0.5.4

Describe the bug The tutorials for SimCLR v1 & v2, MoCo v1 & v2, BYOL, SwAV, Barlow Twins, and DINO don't support Vision Transformer based or ConvNeXt based timm models as encoders.

Attached below is the sample stack trace of the same.

To Reproduce Steps to reproduce the behavior:

# imports as in the repo's tutorials (assumed; not shown in the original report)
from self_supervised.layers import create_encoder, PoolingType
from self_supervised.vision.byol import create_byol_model

arch = "convnext_base_in22ft1k"
encoder = create_encoder(arch, pretrained=False, n_in=3, pool_type=PoolingType.CatAvgMax)
model = create_byol_model(encoder)

Error with full stack trace

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-16-43550aa24362> in <module>()
----> 1 model = create_byol_model(encoder)

7 frames

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in layer_norm(input, normalized_shape, weight, bias, eps)
   2345             layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
   2346         )
-> 2347     return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
   2348 
   2349 

RuntimeError: Given normalized_shape=[1024], expected input with shape [*, 1024], but got input of size[2, 1, 1, 2048]

2.

arch = "swin_base_patch4_window7_224" //"vit_base_patch8_224"
encoder = create_encoder(arch, pretrained=False, n_in=3, pool_type=PoolingType.CatAvgMax) 
model = create_byol_model(encoder)

Error with full stack trace

/usr/local/lib/python3.7/dist-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]

---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

<ipython-input-21-ae77456a6169> in <module>()
      1 arch = "swin_base_patch4_window7_224"
      2 encoder = create_encoder(arch, pretrained=False, n_in=3, pool_type=PoolingType.CatAvgMax)
----> 3 model = create_byol_model(encoder)

6 frames

/usr/local/lib/python3.7/dist-packages/torch/__init__.py in _assert(condition, message)
    674     if type(condition) is not torch.Tensor and has_torch_function((condition,)):
    675         return handle_torch_function(_assert, (condition,), condition, message)
--> 676     assert condition, message
    677 
    678 ################################################################################

AssertionError: Input image height (128) doesn't match model (224).



KeremTurgutlu commented 2 years ago

Both are issues on the timm side; feel free to open an issue in the timm repo. Here is an explanation of the problems:

ConvNeXt

The main problem is that the CatAvgMax pooler concatenates the average-pooled and max-pooled features, so the final feature has dimension 2 x D (where D is the number of channels before pooling). However, the LayerNorm layer that follows is built for D=1024, not D=2048, as the module printout and the short sketch after it show:

Sequential(
  (global_pool): SelectAdaptivePool2d (pool_type=catavgmax, flatten=Identity())
  (norm): LayerNorm2d((1024,), eps=1e-06, elementwise_affine=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (drop): Dropout(p=0.0, inplace=False)
  (fc): Identity()
)
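
To make the mismatch concrete, here is a minimal, self-contained sketch in plain PyTorch (the 1024-channel feature map is assumed for ConvNeXt-Base; the actual head uses timm's LayerNorm2d, which permutes to channels-last internally):

import torch
import torch.nn as nn

feats = torch.randn(2, 1024, 7, 7)    # ConvNeXt-Base feature map, D=1024
avg = nn.AdaptiveAvgPool2d(1)(feats)  # (2, 1024, 1, 1)
mx = nn.AdaptiveMaxPool2d(1)(feats)   # (2, 1024, 1, 1)
pooled = torch.cat([avg, mx], dim=1)  # (2, 2048, 1, 1): catavgmax doubles D
norm = nn.LayerNorm(1024)             # the head norm is still built for D=1024
norm(pooled.permute(0, 2, 3, 1))      # RuntimeError: expected [*, 1024], got [2, 1, 1, 2048]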

There are two workarounds you can apply:

1) Use the original pooling the ConvNeXt model was trained with.

encoder = create_encoder("convnext_base_in22ft1k", n_in=3, pretrained=False, pool_type=None)
model = create_simclr_model(encoder, hidden_size=2048, projection_size=128, nlayers=2)
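
With pool_type=None, create_encoder presumably leaves timm's default average-pooling head in place, so the head's LayerNorm keeps seeing the 1024-dim features it was built for.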

2) Replace the LayerNorm(1024) layer with LayerNorm(2048):

import torch
from timm.models.layers import LayerNorm2d

encoder = create_encoder("convnext_base_in22ft1k", n_in=3, pretrained=False, pool_type=PoolingType.CatAvgMax)
norm = LayerNorm2d(2048)
# tile the original 1024-dim affine parameters over both pooled halves;
# the bias needs duplicating too, otherwise it stays at its zero init
norm.weight.data = torch.cat([encoder.head.norm.weight.data, encoder.head.norm.weight.data])
norm.bias.data = torch.cat([encoder.head.norm.bias.data, encoder.head.norm.bias.data])
encoder.head.norm = norm
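
As a quick sanity check (a sketch, assuming 224x224 inputs), the patched head should now accept the doubled feature width:

x = torch.randn(2, 3, 224, 224)
out = encoder(x)
assert out.shape == (2, 2048)  # catavgmax doubles ConvNeXt-Base's 1024-dim embedding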

Swin

As for the Swin transformer, only average pooling is supported; it is hardcoded in the forward method, so there is no need to pass a pool_type:

>>> encoder.forward_features??
    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.absolute_pos_embed is not None:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        x = self.layers(x)
        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x
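
Putting it together, a sketch that should work for Swin (assuming the same create_encoder API as above, and noting from the AssertionError in the report that swin_*_224 models require 224x224 inputs):

encoder = create_encoder("swin_base_patch4_window7_224", pretrained=False, n_in=3, pool_type=None)
x = torch.randn(2, 3, 224, 224)  # swin_*_224 asserts a fixed 224x224 input size
feats = encoder(x)               # (2, 1024): avg-pooled Swin-Base features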