microsoft / Swin-Transformer

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".
https://arxiv.org/abs/2103.14030
MIT License

Swin Transformer feature map size #210

Open Dreamer312 opened 2 years ago

Dreamer312 commented 2 years ago

Hi, I used a basic Swin Base transformer with your pretrained weights. The input size is (1, 3, 224, 224), and the other settings are the defaults. I want to extract the feature maps from the 4 stages, but these are the feature map sizes I got:

output size torch.Size([1, 784, 256])
output size torch.Size([1, 196, 512])
output size torch.Size([1, 49, 1024])
output size torch.Size([1, 49, 1024])

The expected result should be:

output size torch.Size([1, 3136, 256])
output size torch.Size([1, 784, 256])
output size torch.Size([1, 196, 512])
output size torch.Size([1, 49, 1024])

Here is the code I added in class SwinTransformer(nn.Module):

def forward_seperate_features(self, x):
    # Patch embedding + optional absolute position embedding, as in forward_features
    # (the repo's forward_features guards this with `if self.ape:` instead)
    x = self.patch_embed(x)
    if self.absolute_pos_embed is not None:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)

    result = []
    # self.seperate_layers holds the four stages (the same BasicLayer modules as self.layers)
    for layer in self.seperate_layers:
        stages = nn.Sequential(layer)  # equivalent to calling layer(x) directly
        # the first stage consumes the patch embedding; later stages consume the previous stage's output
        if not result:
            output = stages(x)
        else:
            output = stages(result[-1])
        print(f'output size {output.size()}')
        result.append(output)

    return result
lixingang commented 2 years ago

I have the same issue. A possible solution is to change self.num_layers - 1 to self.num_layers at https://github.com/microsoft/Swin-Transformer/blob/78cec9ac5c746c7e72305a9a24716ddb3fcc043c/models/swin_transformer.py#L532

ed-cho commented 7 months ago

Probably too late, but that's because BasicLayer runs its SwinTransformerBlocks and then PatchMerging (which downsamples), unlike the diagram in the original paper. In other words, the paper describes stage 1 as (linear embedding + Swin Transformer blocks) and stages 2-4 as (patch merging + Swin Transformer blocks), whereas the code is structured as linear embedding + (Swin Transformer blocks + patch merging) * 3 + (Swin Transformer blocks). The two are equivalent overall, but the intermediate feature map sizes are different.
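
If you want the stage sizes as drawn in the paper (3136 x C, 784 x 2C, 196 x 4C, 49 x 8C tokens for a 224x224 input), one option is to record each stage's output before its PatchMerging. Below is a minimal sketch of that idea, not an official API of this repo: it assumes a stock SwinTransformer from models/swin_transformer.py (run from the repo root), where the stages live in model.layers and each BasicLayer exposes .blocks and an optional .downsample; the function name extract_stage_features is just for illustration.

import torch
from models.swin_transformer import SwinTransformer

def extract_stage_features(model, x):
    # Patch embedding + optional absolute position embedding,
    # mirroring SwinTransformer.forward_features
    x = model.patch_embed(x)
    if model.ape:
        x = x + model.absolute_pos_embed
    x = model.pos_drop(x)

    features = []
    for layer in model.layers:          # each layer is a BasicLayer
        for blk in layer.blocks:        # run the stage's Swin blocks
            x = blk(x)
        features.append(x)              # stage output as drawn in the paper
        if layer.downsample is not None:
            x = layer.downsample(x)     # PatchMerging feeds the next stage
    return features

model = SwinTransformer().eval()        # Swin-T defaults; load your checkpoint as usual
feats = extract_stage_features(model, torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.size())

This only changes where the feature is read relative to the PatchMerging; the computation is otherwise identical to the stock forward_features.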