PyTorch reimplementation of "FlexiViT: One Model for All Patch Sizes".
pip install flexivit-pytorch
Or install the entire repo with:
git clone https://github.com/bwconrad/flexivit
cd flexivit/
pip install -r requirements.txt
import torch
from flexivit_pytorch import FlexiVisionTransformer
net = FlexiVisionTransformer(
img_size=240,
base_patch_size=32,
patch_size_seq=(8, 10, 12, 15, 16, 20, 24, 30, 40, 48),
base_pos_embed_size=7,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
)
img = torch.randn(1, 3, 240, 240)
preds = net(img)
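Each patch size should divide the image side so the image splits into a whole grid of patches. Smaller patches produce more tokens and therefore cost more compute. The sketch below computes the grid and token count at img_size=240 for the patch sizes 8 to 48 used by FlexiViT. It uses only the standard library, and the helper patch_grid is hypothetical, not part of flexivit_pytorch.

```python
def patch_grid(img_size: int, patch_size: int) -> tuple:
    """Hypothetical helper: (grid side, token count) for a square image."""
    if img_size % patch_size != 0:
        raise ValueError(f"patch size {patch_size} does not divide image size {img_size}")
    side = img_size // patch_size
    return side, side * side


patch_sizes = (8, 10, 12, 15, 16, 20, 24, 30, 40, 48)
for ps in patch_sizes:
    side, tokens = patch_grid(240, ps)
    # Smaller patches give a finer grid and more tokens.
    print(f"patch {ps:2d}: {side:2d}x{side:2d} grid, {tokens:4d} tokens")
```

Because 240 is divisible by all ten values, one image resolution serves every patch size in the sequence.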
You can also initialize default network configurations:
from flexivit_pytorch import (flexivit_base, flexivit_huge, flexivit_large,
flexivit_small, flexivit_tiny)
net = flexivit_tiny()
net = flexivit_small()
net = flexivit_base()
net = flexivit_large()
net = flexivit_huge()
The patch embedding layer of a standard pretrained vision transformer can be resized to any patch size using the pi_resize_patch_embed() function. An example of doing this with the timm library is the following:
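PI-resizing, as described in the FlexiViT paper, chooses new patch-embedding weights w_new so that a resized patch yields the same token as the original patch: <x, w> = <Bx, w_new>, where B is the linear resize operator. The least-squares solution is w_new = pinv(B^T) @ w, and it is exact when upsampling. The NumPy sketch below illustrates that idea under one assumption: B is built as bilinear resizing with half-pixel centers (the align_corners=False convention), expressed as a Kronecker product of 1-D interpolation matrices. It is not the library's pi_resize_patch_embed implementation, and all helper names are hypothetical.

```python
import numpy as np


def linear_resize_matrix(n_in: int, n_out: int) -> np.ndarray:
    """(n_out x n_in) matrix for 1-D linear interpolation with half-pixel
    centers (the align_corners=False convention)."""
    R = np.zeros((n_out, n_in))
    scale = n_in / n_out
    for i in range(n_out):
        src = max((i + 0.5) * scale - 0.5, 0.0)  # clamp at the left border
        i0 = min(int(np.floor(src)), n_in - 1)
        i1 = min(i0 + 1, n_in - 1)  # clamp at the right border
        frac = src - i0
        R[i, i0] += 1.0 - frac
        R[i, i1] += frac
    return R


def bilinear_resize_matrix(old_hw, new_hw) -> np.ndarray:
    """Matrix B such that B @ x.ravel() is the bilinear resize of an (h, w) patch x."""
    Rh = linear_resize_matrix(old_hw[0], new_hw[0])
    Rw = linear_resize_matrix(old_hw[1], new_hw[1])
    return np.kron(Rh, Rw)  # shape: (new_h * new_w, old_h * old_w)


def pi_resize_kernel(weight: np.ndarray, new_hw) -> np.ndarray:
    """Hypothetical PI-resize of a kernel (out, in, h, w) -> (out, in, new_h, new_w),
    setting w_new = pinv(B^T) @ w for every (out, in) filter."""
    out_ch, in_ch, h, w = weight.shape
    B = bilinear_resize_matrix((h, w), new_hw)
    flat = weight.reshape(out_ch * in_ch, h * w)
    # Row-vector form of pinv(B^T) @ w, using pinv(B^T) == pinv(B)^T.
    new_flat = flat @ np.linalg.pinv(B)
    return new_flat.reshape(out_ch, in_ch, new_hw[0], new_hw[1])


# Small stand-in for a (768, 3, 16, 16) kernel: 8 output channels, 4x4 patches.
rng = np.random.default_rng(0)
w_old = rng.normal(size=(8, 3, 4, 4))
w_new = pi_resize_kernel(w_old, (6, 6))

# One RGB patch and its bilinear upsampling to 6x6, per channel.
x = rng.normal(size=(3, 4, 4))
B = bilinear_resize_matrix((4, 4), (6, 6))
x_up = x.reshape(3, 16) @ B.T  # (3, 36)

# Token = sum over channels and pixels of weight * patch.
token_old = np.einsum("ock,ck->o", w_old.reshape(8, 3, 16), x.reshape(3, 16))
token_new = np.einsum("ock,ck->o", w_new.reshape(8, 3, 36), x_up)
print(np.allclose(token_old, token_new))  # True when upsampling
```

When downsampling (for example 16 to 8), B has fewer rows than columns. The same formula then gives only the least-squares best match rather than an exact one.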
from timm import create_model
from timm.layers.pos_embed import resample_abs_pos_embed
from flexivit_pytorch import pi_resize_patch_embed
# Load the pretrained model's state_dict
state_dict = create_model("vit_base_patch16_224", pretrained=True).state_dict()
# Resize the patch embedding
new_patch_size = (32, 32)
state_dict["patch_embed.proj.weight"] = pi_resize_patch_embed(
patch_embed=state_dict["patch_embed.proj.weight"], new_patch_size=new_patch_size
)
# Interpolate the position embedding size
image_size = 224
grid_size = image_size // new_patch_size[0]
state_dict["pos_embed"] = resample_abs_pos_embed(
posemb=state_dict["pos_embed"], new_size=[grid_size, grid_size]
)
# Load the new weights into a model with the target image and patch sizes
net = create_model(
"vit_base_patch16_224", img_size=image_size, patch_size=new_patch_size
)
net.load_state_dict(state_dict, strict=True)
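resample_abs_pos_embed keeps the class token's embedding and interpolates the remaining grid of position embeddings to the new grid size. In this example, 224 // 32 = 7, so the 14 x 14 = 196 grid embeddings of ViT-B/16 become 7 x 7 = 49. The NumPy sketch below illustrates that step with two simplifying assumptions: there is a single prefix (class) token, and interpolation is plain bilinear. timm's function uses bicubic interpolation by default, so its values will differ. resample_pos_embed is a hypothetical name.

```python
import numpy as np


def linear_resize_matrix(n_in: int, n_out: int) -> np.ndarray:
    """(n_out x n_in) 1-D linear interpolation matrix, half-pixel centers."""
    R = np.zeros((n_out, n_in))
    scale = n_in / n_out
    for i in range(n_out):
        src = max((i + 0.5) * scale - 0.5, 0.0)
        i0 = min(int(np.floor(src)), n_in - 1)
        i1 = min(i0 + 1, n_in - 1)
        frac = src - i0
        R[i, i0] += 1.0 - frac
        R[i, i1] += frac
    return R


def resample_pos_embed(pos_embed: np.ndarray, new_grid: int, num_prefix_tokens: int = 1) -> np.ndarray:
    """Hypothetical: resize (1, prefix + g*g, dim) to (1, prefix + new_grid**2, dim)."""
    prefix = pos_embed[:, :num_prefix_tokens]  # class token embedding is kept as-is
    grid_tokens = pos_embed[:, num_prefix_tokens:]
    old_grid = int(round(grid_tokens.shape[1] ** 0.5))
    if old_grid * old_grid != grid_tokens.shape[1]:
        raise ValueError("position embedding grid is not square")
    dim = grid_tokens.shape[2]
    grid = grid_tokens.reshape(old_grid, old_grid, dim)
    R = linear_resize_matrix(old_grid, new_grid)
    # Separable bilinear resize per channel: R @ grid[..., d] @ R.T
    resized = np.einsum("ij,jkd,lk->ild", R, grid, R)
    return np.concatenate([prefix, resized.reshape(1, new_grid * new_grid, dim)], axis=1)


rng = np.random.default_rng(0)
pos = rng.normal(size=(1, 1 + 14 * 14, 768))  # ViT-B/16 at 224px: class token + 14x14 grid
new_pos = resample_pos_embed(pos, 224 // 32)  # 7x7 grid for patch size 32
print(new_pos.shape)  # (1, 50, 768)
```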
The convert_patch_embed.py script can similarly do the resizing on any local model checkpoint file. For example, to resize to a patch size of 20:
python convert_patch_embed.py -i vit-16.pt -o vit-20.pt -n patch_embed.proj.weight -ps 20
or to a patch size of height 10 and width 15:
python convert_patch_embed.py -i vit-16.pt -o vit-10-15.pt -n patch_embed.proj.weight -ps 10 15
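The -ps flag takes either one value (a square patch) or two values (height, then width). The sketch below shows one plausible way to parse flags like these with argparse. It is a hypothetical stand-in, not the script's actual code, and the long option names are assumptions.

```python
import argparse


def parse_args(argv=None) -> argparse.Namespace:
    """Hypothetical parser mirroring the -i/-o/-n/-ps flags shown above."""
    parser = argparse.ArgumentParser(description="Resize a checkpoint's patch embedding (sketch).")
    parser.add_argument("-i", "--input", required=True, help="input checkpoint path")
    parser.add_argument("-o", "--output", required=True, help="output checkpoint path")
    parser.add_argument("-n", "--name", required=True,
                        help="state-dict key of the patch embedding weights")
    parser.add_argument("-ps", "--patch-size", type=int, nargs="+", required=True,
                        help="one value (square) or two values (height width)")
    args = parser.parse_args(argv)
    # Normalize to a (height, width) tuple.
    if len(args.patch_size) == 1:
        args.patch_size = (args.patch_size[0], args.patch_size[0])
    elif len(args.patch_size) == 2:
        args.patch_size = tuple(args.patch_size)
    else:
        parser.error("-ps takes one or two integers")
    return args


args = parse_args(["-i", "vit-16.pt", "-o", "vit-10-15.pt",
                   "-n", "patch_embed.proj.weight", "-ps", "10", "15"])
print(args.patch_size)  # (10, 15)
```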
The -n argument should correspond to the name of the patch embedding weights in the checkpoint's state dict.
The eval.py script can be used to evaluate pretrained Vision Transformer models at different patch sizes. For example, to evaluate a ViT-B/16 at a patch size of 20 on the ImageNet-1k validation set, you can run:
python eval.py --accelerator gpu --devices 1 --precision 16 --model.resize_type pi \
--model.weights vit_base_patch16_224.augreg_in21k_ft_in1k --data.root path/to/val/data/ \
--data.num_classes 1000 --model.patch_size 20 --data.size 224 --data.crop_pct 0.9 \
--data.mean "[0.5,0.5,0.5]" --data.std "[0.5,0.5,0.5]" --data.batch_size 256
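A timm-style evaluation pipeline with --data.size 224 and --data.crop_pct 0.9 first resizes the shorter image side to floor(224 / 0.9) = 248 pixels. It then takes a 224 x 224 center crop. A mean and std of 0.5 then map pixel values from [0, 1] to [-1, 1]. The sketch below works through that arithmetic, assuming eval.py follows timm's convention. The helper names are hypothetical.

```python
import math


def eval_resize_size(img_size: int, crop_pct: float) -> int:
    """Hypothetical: shorter-side resize applied before a center crop (timm-style)."""
    return math.floor(img_size / crop_pct)


def center_crop_box(height: int, width: int, size: int) -> tuple:
    """(top, left, bottom, right) of a centered size x size crop."""
    top = (height - size) // 2
    left = (width - size) // 2
    return top, left, top + size, left + size


def normalize(pixel: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Map an 8-bit pixel value to the normalized model input range."""
    return (pixel / 255.0 - mean) / std


resize = eval_resize_size(224, 0.9)
box = center_crop_box(resize, resize, 224)
print(resize, box, normalize(0), normalize(255))  # 248 (12, 12, 236, 236) -1.0 1.0
```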
--model.weights should correspond to a timm model name.
The --data.root directory should be organized in the TorchVision ImageFolder structure. Alternatively, an LMDB file can be used by setting --data.is_lmdb True and having --data.root point to the .lmdb file.
To reproduce timm's baseline results, make sure that --data.size, --data.crop_pct, --data.interpolation (all listed here), --data.mean, and --data.std (in general found here) are correct for the model. --data.mean imagenet and --data.mean clip can be set to use the respective default values (same for --data.std).
Run python eval.py --help for a list and descriptions of all arguments.
The following experiments test using PI-resizing to change the patch size of standard ViT models during evaluation. All models have been fine-tuned on ImageNet-1k with a fixed patch size and are evaluated with different patch sizes.
@article{beyer2022flexivit,
title={FlexiViT: One Model for All Patch Sizes},
author={Beyer, Lucas and Izmailov, Pavel and Kolesnikov, Alexander and Caron, Mathilde and Kornblith, Simon and Zhai, Xiaohua and Minderer, Matthias and Tschannen, Michael and Alabdulmohsin, Ibrahim and Pavetic, Filip},
journal={arXiv preprint arXiv:2212.08013},
year={2022}
}