ChristophReich1996 / Swin-Transformer-V2

PyTorch reimplementation of the paper "Swin Transformer V2: Scaling Up Capacity and Resolution" [CVPR 2022].
https://arxiv.org/abs/2111.09883
MIT License

About Checkpoints #9

WY-2022 commented 2 years ago

Hi! I have another question. If I just pip install the package and then write:

import torch.nn as nn
from swin_transformer_v2 import swin_transformer_v2_t, SwinTransformerV2

class SWIN(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        self.num_classes = num_classes
        # self.pool = nn.MaxPool2d(2, 2)
        self.encoder: SwinTransformerV2 = swin_transformer_v2_t(in_channels=3,
                                                                window_size=8,
                                                                input_resolution=(1024, 1280),
                                                                sequential_self_attention=False,
                                                                use_checkpoint=True)
        self.p = self.encoder.patch_embedding
        self.encoder0 = self.encoder.stages[0]
        ...

How do I use the checkpoints now? And is there a pre-trained model for v2_base? (Also, when I just run the code above, a weird warning appears: 'warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")')

ChristophReich1996 commented 2 years ago

To load the provided checkpoints, you need to initialize the network in its training configuration, load the state dict, and then change the resolution/window size to your needs. Here is an example for the CIFAR10 checkpoint:

import torch
from swin_transformer_v2 import swin_transformer_v2_t, SwinTransformerV2

# Initialize the tiny model in its CIFAR10 training configuration.
swin_transformer: SwinTransformerV2 = swin_transformer_v2_t(input_resolution=(32, 32),
                                                            window_size=8,
                                                            sequential_self_attention=False,
                                                            use_checkpoint=True)
# Load the pre-trained weights.
swin_transformer.load_state_dict(torch.load("path_to_weights/cifar10_swin_t_best_model_backbone.pt"))
# Only then change the resolution (and window size) to the target setting.
swin_transformer.update_resolution(new_window_size=8, new_input_resolution=(1024, 1280))
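
To verify the adapted model, a quick forward pass with a dummy batch could look like this (a minimal sketch; it assumes, as the "backbone" checkpoint name suggests, that the model returns the feature maps of its stages):

dummy_input = torch.rand(1, 3, 1024, 1280)  # matches the new input resolution
with torch.no_grad():
    features = swin_transformer(dummy_input)  # assumed: one feature map per stage
for feature_map in features:
    print(feature_map.shape)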

Here is an example for the Places365 checkpoint:

import torch
from swin_transformer_v2 import swin_transformer_v2_b, SwinTransformerV2

swin_transformer: SwinTransformerV2 = swin_transformer_v2_b(input_resolution=(256, 256),
                                                            window_size=8,
                                                            sequential_self_attention=False,
                                                            use_checkpoint=True)
swin_transformer.load_state_dict(torch.load("path_to_weights/places365_swin_b_best_model_backbone.pt"))
swin_transformer.update_resolution(new_window_size=8, new_input_resolution=(1024, 1280))

The CIFAR10 checkpoint is for the tiny model and the Places365 checkpoint is for the base model.
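
Applied to the wrapper module from your question, the same initialize-load-update pattern would look roughly like this (a sketch using the CIFAR10 tiny checkpoint and the placeholder weight path from above):

import torch
import torch.nn as nn
from swin_transformer_v2 import swin_transformer_v2_t, SwinTransformerV2

class SWIN(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        self.num_classes = num_classes
        # 1. Initialize with the checkpoint's training configuration.
        encoder: SwinTransformerV2 = swin_transformer_v2_t(input_resolution=(32, 32),
                                                           window_size=8,
                                                           sequential_self_attention=False,
                                                           use_checkpoint=True)
        # 2. Load the pre-trained weights.
        encoder.load_state_dict(torch.load("path_to_weights/cifar10_swin_t_best_model_backbone.pt"))
        # 3. Only then change resolution and window size.
        encoder.update_resolution(new_window_size=8, new_input_resolution=(1024, 1280))
        self.encoder = encoder
        self.encoder0 = self.encoder.stages[0]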

Please note that pre-trained ImageNet1k weights are also available in the timm library!
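
For example, loading such ImageNet1k weights via timm could look like this (a minimal sketch assuming a recent timm version; run timm.list_models("swinv2*") to check which variant names your installation provides):

import timm
import torch

# The model name is an example; available SwinV2 variants depend on the timm version.
model = timm.create_model("swinv2_tiny_window8_256", pretrained=True)
model.eval()
with torch.no_grad():
    logits = model(torch.rand(1, 3, 256, 256))
print(logits.shape)  # classification logits, e.g. torch.Size([1, 1000])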