WY-2022 closed this issue 2 years ago.
To load the provided checkpoints, initialize the network with the training configuration, load the state dict, and then change the resolution/window size to suit your needs. Here is an example for the CIFAR10 checkpoint:
```python
import torch
from swin_transformer_v2 import swin_transformer_v2_t, SwinTransformerV2

# Build the tiny model with the configuration used for CIFAR10 training
swin_transformer: SwinTransformerV2 = swin_transformer_v2_t(input_resolution=(32, 32),
                                                            window_size=8,
                                                            sequential_self_attention=False,
                                                            use_checkpoint=True)
# Load the trained weights, then adapt the model to the target resolution
swin_transformer.load_state_dict(torch.load("path_to_weights/cifar10_swin_t_best_model_backbone.pt"))
swin_transformer.update_resolution(new_window_size=8, new_input_resolution=(1024, 1280))
```
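Picking `new_input_resolution` and `new_window_size` for `update_resolution` is largely arithmetic: window attention needs each stage's feature map to split evenly into windows. The sketch below checks this using only the standard library. It assumes, rather than confirms for this package, a patch size of 4, four stages with 2x downsampling between them, and the original Swin rule that a stage smaller than the window uses a window equal to its shorter side. `stage_resolutions` and `check_window_size` are hypothetical helpers, not part of `swin_transformer_v2`.

```python
from typing import List, Tuple


def stage_resolutions(input_resolution: Tuple[int, int],
                      patch_size: int = 4,
                      num_stages: int = 4) -> List[Tuple[int, int]]:
    """Feature-map size (H, W) per stage, ASSUMING a patch embedding with
    stride `patch_size` and 2x downsampling between consecutive stages."""
    height, width = input_resolution
    if height % patch_size or width % patch_size:
        raise ValueError("input resolution must be divisible by the patch size")
    height, width = height // patch_size, width // patch_size
    sizes = []
    for _ in range(num_stages):
        if min(height, width) < 1:
            raise ValueError("input resolution too small for this many stages")
        sizes.append((height, width))
        height, width = height // 2, width // 2
    return sizes


def check_window_size(input_resolution: Tuple[int, int],
                      window_size: int,
                      patch_size: int = 4,
                      num_stages: int = 4) -> List[Tuple[Tuple[int, int], int, bool]]:
    """Return (stage size, effective window, fits) for every stage."""
    report = []
    for height, width in stage_resolutions(input_resolution, patch_size, num_stages):
        # ASSUMED clipping rule (from the original Swin code): a stage smaller
        # than the window uses a window equal to its shorter side.
        window = min(window_size, height, width)
        fits = height % window == 0 and width % window == 0
        report.append(((height, width), window, fits))
    return report


# Hypothetical usage: the target resolution and window size from the example above
for size, window, fits in check_window_size((1024, 1280), window_size=8):
    print(size, window, "ok" if fits else "does not split evenly into windows")
```

Under these assumptions, the (1024, 1280) target with window size 8 gives stage sizes 256 x 320, 128 x 160, 64 x 80 and 32 x 40, all of which divide evenly by 8.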
Here is an example for the Places365 checkpoint:
```python
import torch
from swin_transformer_v2 import swin_transformer_v2_b, SwinTransformerV2

swin_transformer: SwinTransformerV2 = swin_transformer_v2_b(input_resolution=(256, 256),
                                                            window_size=8,
                                                            sequential_self_attention=False,
                                                            use_checkpoint=True)
swin_transformer.load_state_dict(torch.load("path_to_weights/places365_swin_b_best_model_backbone.pt"))
swin_transformer.update_resolution(new_window_size=8, new_input_resolution=(1024, 1280))
```
The CIFAR10 checkpoint is for the tiny model (`swin_transformer_v2_t`) and the Places365 checkpoint is for the base model (`swin_transformer_v2_b`).
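Since each checkpoint has to be loaded into a network built with its own training configuration, a small lookup table keyed by file name can prevent mismatches such as constructing the base model for the tiny checkpoint. This is a hypothetical helper (`TrainingConfig`, `CHECKPOINT_CONFIGS` and `config_for` are not part of the package); the values are copied from the CIFAR10 and Places365 examples.

```python
import os
from typing import Dict, NamedTuple, Tuple


class TrainingConfig(NamedTuple):
    constructor: str                  # factory function name in swin_transformer_v2
    input_resolution: Tuple[int, int]
    window_size: int


# Training configurations copied from the CIFAR10 and Places365 examples
CHECKPOINT_CONFIGS: Dict[str, TrainingConfig] = {
    "cifar10_swin_t_best_model_backbone.pt":
        TrainingConfig("swin_transformer_v2_t", (32, 32), 8),
    "places365_swin_b_best_model_backbone.pt":
        TrainingConfig("swin_transformer_v2_b", (256, 256), 8),
}


def config_for(checkpoint_path: str) -> TrainingConfig:
    """Look up the training configuration by the checkpoint's file name."""
    file_name = os.path.basename(checkpoint_path)
    if file_name not in CHECKPOINT_CONFIGS:
        raise ValueError(f"no known training configuration for {file_name!r}")
    return CHECKPOINT_CONFIGS[file_name]


config = config_for("path_to_weights/places365_swin_b_best_model_backbone.pt")
print(config.constructor, config.input_resolution, config.window_size)
```

The constructor name can then be resolved with `getattr` on the imported `swin_transformer_v2` module before calling `load_state_dict` and `update_resolution`.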
Please note that ImageNet-1k pre-trained weights are also available in the timm library!
Hi! I have another question. If I just install via pip and then run:
How do I use the checkpoint now? And is there a pre-trained model for v2_base? (Also, when I run it as above, a weird warning appears: `warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")`)
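The warning in this question comes from PyTorch's `torch.utils.checkpoint`: it is raised when none of the tensors passed into a checkpointed call has `requires_grad=True`, which is common when only running a forward pass (for example under `torch.no_grad()`). For inference it is harmless, and building the model with `use_checkpoint=False` should avoid it, assuming that flag is what turns on `torch.utils.checkpoint` in this package; if it shows up during training, it would mean the checkpointed block receives no gradients, which is worth checking. To keep checkpointing on and hide just this message, a minimal sketch using only the standard `warnings` module (the helper name is hypothetical):

```python
import re
import warnings

# Start of the message torch.utils.checkpoint emits when no input requires grad
CHECKPOINT_WARNING = "None of the inputs have requires_grad=True"


def ignore_checkpoint_grad_warning() -> None:
    """Hide only this warning; every other warning is still shown.

    The filter is installed process-wide, so call it once at start-up.
    """
    warnings.filterwarnings("ignore", message=re.escape(CHECKPOINT_WARNING),
                            category=UserWarning)


ignore_checkpoint_grad_warning()
```

Filtering by message rather than silencing all `UserWarning`s keeps other, possibly important, warnings visible.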