huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXt, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0

Properly get feature maps from Swin-V2 instead of norm #1455

Closed: sarmientoj24 closed this issue 1 year ago

sarmientoj24 commented 2 years ago

I want to use Swin-V2 as a backbone and get the feature maps without the final norm applied. How do I get them?

sarmientoj24 commented 2 years ago

I tried a different Swin-V2 repo and I get these four feature maps:

(torch.Size([1, 96, 64, 64]),
 torch.Size([1, 192, 32, 32]),
 torch.Size([1, 384, 16, 16]),
 torch.Size([1, 768, 8, 8]))

I tried editing your implementation like this:

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.absolute_pos_embed is not None:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        # collect the output of every stage instead of only the last
        outputs = []
        print(len(self.layers))  # number of stages
        for layer in self.layers:
            x = layer(x)
            print(x.shape)  # stage output, (B, L, C)
            outputs.append(x)

        # x = self.norm(x)  # B L C
        return outputs

    model = timm.create_model('swinv2_tiny_window8_256', img_size=(256, 256), num_classes=2, pretrained=True)
    sample = model.forward_features(torch.randn(1, 3, 256, 256))

but I get these shapes:

torch.Size([1, 1024, 192])
torch.Size([1, 256, 384])
torch.Size([1, 64, 768])
torch.Size([1, 64, 768])
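
For reference, those are flattened (B, L, C) token sequences rather than (B, C, H, W) maps; a minimal sketch of the reshape, assuming the tokens form a square grid:

    import math
    import torch

    def nlc_to_nchw(t: torch.Tensor) -> torch.Tensor:
        # (B, L, C) token sequence -> (B, C, H, W) feature map,
        # assuming the L tokens form a square H x W grid
        B, L, C = t.shape
        H = W = math.isqrt(L)
        return t.transpose(1, 2).reshape(B, C, H, W)

    # e.g. torch.Size([1, 1024, 192]) -> torch.Size([1, 192, 32, 32])
    print(nlc_to_nchw(torch.randn(1, 1024, 192)).shape)
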
rwightman commented 2 years ago

@sarmientoj24 unfortunately, Swin v1 and v2 (adapted from the official Microsoft modelling code) put the downsample at the end of each stage, so you can't simply take the output of each stage... in their object detection/segmentation code they actually have to modify the stage to output both feature maps, which is silly; you can just reorganize the stages to have the downsample first... https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py#L362-L402
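
A minimal sketch of that workaround, assuming each stage exposes `blocks` and a trailing `downsample` as in the Microsoft code linked above:

    # Run each stage's blocks and its downsample separately so the
    # pre-downsample feature map can be captured (sketch; attribute
    # names assumed from the linked Swin code).
    outputs = []
    for layer in self.layers:
        for blk in layer.blocks:
            x = blk(x)
        outputs.append(x)  # stage output *before* patch merging
        if layer.downsample is not None:
            x = layer.downsample(x)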

I actually fixed this in my Swin v2 implementation (before the original came out): https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer_v2_cr.py#L597-L616. In theory I could make that implementation work with either v1 or official v2 checkpoints (remapping them on load)...

For my swinv2_cr impl, and the recent maxvit, coatnet, and gcvit models, it would be pretty easy to enable features_only support, since I organized the stages appropriately.

sarmientoj24 commented 2 years ago

Does that mean I should just use swinv2_cr instead?

rwightman commented 2 years ago

@sarmientoj24 it'd be worth testing your use case with swinv2_cr_small_ns_224 to see if it works better; if so, I could prioritize making the other v1/v2 models available through that impl...

sarmientoj24 commented 2 years ago

@rwightman

    model = timm.create_model('swinv2_cr_small_ns_224', img_size=(224, 224), num_classes=2, features_only=True, pretrained=True)

RuntimeError: features_only not implemented for Vision Transformer models.

When I remove features_only...

    sample = model.forward_features(torch.randn(2, 3, 224, 224))
    sample.shape

torch.Size([2, 768, 7, 7])

Using forward:

    sample = model.forward(torch.randn(2, 3, 224, 224))
    sample.shape

torch.Size([2, 2])

sarmientoj24 commented 2 years ago

Seems like the cr version is from Christopher's, which I have already used before. That one has no sequential self-attention, right? And this one too?

rwightman commented 2 years ago

@sarmientoj24 it was based on his impl but ended up with quite a few changes; I didn't include the sequential attn as I wasn't convinced it was working properly...

The forward_features isn't added yet; what I meant was that you should be able to apply the modifications you made for the other Swin and get the shapes you expect from the blocks. If that works as you expect, it would be a good signal for me to add full support...
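
A hedged sketch of that test (attribute names `patch_embed` and `stages` assumed from the swinv2_cr source linked above):

    import timm
    import torch

    model = timm.create_model('swinv2_cr_small_ns_224', pretrained=True, num_classes=2)
    x = model.patch_embed(torch.randn(2, 3, 224, 224))
    feats = []
    for stage in model.stages:
        x = stage(x)
        feats.append(x)  # one NCHW feature map per stage
    print([f.shape for f in feats])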

sarmientoj24 commented 2 years ago

@rwightman is sequential attention a default in Microsoft's original Swin-V2?

rwightman commented 2 years ago

@sarmientoj24 no, they did not release it, or any of the 'really' big models that use it, like the giant ones...

Bailey-24 commented 1 year ago

RuntimeError: Unknown model (swinv2_tiny_window8_256)

Why does this happen, and how do I fix it?
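
A quick way to see which names are registered in the installed timm version; an "Unknown model" error usually means the name isn't in this list:

    import timm
    # list all registered swinv2 model names in this timm install
    print(timm.list_models('swinv2*'))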

rwightman commented 1 year ago

As per #1438... feature extraction for Swin v1 & v2 is supported now, in NHWC format (the v2_cr models support NCHW like other convnets, but there's a slight performance penalty to do that for the other v1/v2 setup, so it doesn't permute by default).
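
A minimal sketch of the now-supported path (output shapes hedged; NHWC per the note above):

    import timm
    import torch

    model = timm.create_model('swinv2_tiny_window8_256', pretrained=True, features_only=True)
    feats = model(torch.randn(1, 3, 256, 256))
    for f in feats:
        print(f.shape)  # one NHWC feature map per stage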