Closed luluenen closed 2 years ago
Hi,
I'm planning to release at least ResNet and ViT for CIFAR-100. I'll comment on this issue after I release those pretrained models. For ImageNet-1K, I recommend using timm (e.g., please refer to fourier_analysis.ipynb).
I just released pretrained models for CIFAR-100 (ResNet-50, ViT-Ti, PiT-Ti, and Swin-Ti).
The code snippets below show how to (a) load the pretrained models and (b) optionally convert them into block sequences. Please refer to featuremap_variance.ipynb (Colab notebook) for example code.
# ResNet-50
import torch  # needed below for torch.cuda / torch.load

import models

# a. download and load a pretrained model for CIFAR-100
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/resnet_50_cifar100_691cc9a9e4.pth.tar"
path = "checkpoints/resnet_50_cifar100_691cc9a9e4.pth.tar"
models.download(url=url, path=path)

name = "resnet_50"
# `model_args` is not defined anywhere in this snippet, so the original
# `model_args.get("stem", False)` raised NameError. Pass the fallback value
# of that expression explicitly instead.
model = models.get_model(name, num_classes=100,  # timm does not provide a ResNet for CIFAR
                         stem=False)

# Map the checkpoint tensors to the GPU when CUDA is available, else the CPU.
map_location = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=map_location)
# The weights are stored under the "state_dict" key of the checkpoint.
model.load_state_dict(checkpoint["state_dict"])

# b. model → blocks. `blocks` is a sequence of blocks
# `layer1`..`layer4` are unpacked so that each of their children becomes its
# own entry; `layer0` and `classifier` are kept as single entries.
blocks = [
    model.layer0,
    *model.layer1,
    *model.layer2,
    *model.layer3,
    *model.layer4,
    model.classifier,
]
# ViT-Ti
import copy
import timm
import torch
import torch.nn as nn
import models
# a. download and load a pretrained model for CIFAR-100
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/vit_ti_cifar100_9857b21357.pth.tar"
path = "checkpoints/vit_ti_cifar100_9857b21357.pth.tar"
# Fetch the checkpoint at `url` into the local file `path`.
models.download(url=url, path=path)
# Build the timm VisionTransformer; the first argument row holds the CIFAR
# input settings, the second row the ViT-Ti size settings.
model = timm.models.vision_transformer.VisionTransformer(
num_classes=100, img_size=32, patch_size=2, # for CIFAR
embed_dim=192, depth=12, num_heads=3, qkv_bias=False, # for ViT-Ti
)
model.name = "vit_ti"
# NOTE(review): `models.stats` is defined outside this snippet; presumably it
# reports statistics of the model — confirm.
models.stats(model)
# Map the checkpoint tensors to the GPU when CUDA is available, else the CPU.
map_location = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=map_location)
# The weights are stored under the "state_dict" key of the checkpoint.
model.load_state_dict(checkpoint["state_dict"])
# b. model → blocks. `blocks` is a sequence of blocks
class PatchEmbed(nn.Module):
    """Input stage of the ViT expressed as a standalone block.

    Runs the model's patch embedding, prepends the class token, adds the
    position embedding and applies ``pos_drop``.

    Args:
        model: Object exposing ``patch_embed``, ``cls_token``, ``pos_embed``
            and ``pos_drop``. It is deep-copied, so the caller's model is
            never modified through this block.
    """

    def __init__(self, model):
        super().__init__()
        self.model = copy.deepcopy(model)

    def forward(self, x, **kwargs):
        """Return the embedded token sequence for ``x``; ``kwargs`` are ignored."""
        # assumes patch_embed returns (batch, tokens, dim) — TODO confirm
        x = self.model.patch_embed(x)
        # Repeat the class token along dim 0 so there is one per batch item.
        cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)
        # The class token becomes index 0 along the token dimension.
        x = torch.cat((cls_token, x), dim=1)
        x = self.model.pos_drop(x + self.model.pos_embed)
        return x
class Residual(nn.Module):
    """Skip-connection wrapper computing ``fn(x) + x``.

    Args:
        *fn: Modules applied in order; they are chained in one
            ``nn.Sequential`` stored as ``self.fn``.
    """

    def __init__(self, *fn):
        super().__init__()
        self.fn = nn.Sequential(*fn)

    def forward(self, x, **kwargs):
        """Return the output of the wrapped chain plus the unchanged input."""
        # NOTE(review): kwargs are forwarded to nn.Sequential, whose forward
        # takes only the input, so any non-empty kwargs raise TypeError.
        return self.fn(x, **kwargs) + x
class Lambda(nn.Module):
    """Adapter that lets a plain callable be used as an ``nn.Module``.

    Args:
        fn: Callable taking the input and returning the output.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        """Return ``fn(x)``."""
        return self.fn(x)
def flatten(xs_list):
    """Return one flat list with the items of every inner iterable, in order.

    Only one level of nesting is removed.
    """
    return [x for xs in xs_list for x in xs]
# model → blocks. `blocks` is a sequence of blocks
# Order: the patch-embedding stage, then for every entry of `model.blocks` an
# attention residual (norm1 → attn) followed by an MLP residual (norm2 → mlp),
# then the classification stage.
blocks = [
PatchEmbed(model),
*flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)]
for b in model.blocks]),
# Final norm, keep index 0 along dim 1 (where PatchEmbed put the class token), then the head.
nn.Sequential(model.norm, Lambda(lambda x: x[:, 0]), model.head),
]
# PiT-Ti
import copy
import math

import timm
import torch
import torch.nn as nn

import models  # provides `download` and `stats`, both called below

# a. download and load a pretrained model for CIFAR-100
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/pit_ti_cifar100_0645889efb.pth.tar"
path = "checkpoints/pit_ti_cifar100_0645889efb.pth.tar"
models.download(url=url, path=path)

# The first argument row holds the CIFAR-100 input settings, the second row
# the PiT-Ti size settings.
model = timm.models.pit.PoolingVisionTransformer(
    num_classes=100, img_size=32, patch_size=2, stride=1,  # for CIFAR-100
    base_dims=[32, 32, 32], depth=[2, 6, 4], heads=[2, 4, 8], mlp_ratio=4,  # for PiT-Ti
)
model.name = "pit_ti"
models.stats(model)

# Map the checkpoint tensors to the GPU when CUDA is available, else the CPU.
map_location = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=map_location)
# The weights are stored under the "state_dict" key of the checkpoint.
model.load_state_dict(checkpoint["state_dict"])
# b. model → blocks. `blocks` is a sequence of blocks
class PatchEmbed(nn.Module):
    """Input stage of the PiT expressed as a standalone block.

    Runs the patch embedding, adds the position embedding, applies
    ``pos_drop`` and returns the result together with the class tokens,
    which are kept separate here.

    Args:
        model: Object exposing ``patch_embed``, ``pos_embed``, ``pos_drop``
            and ``cls_token``. It is deep-copied, so the caller's model is
            never modified through this block.
    """

    def __init__(self, model):
        super().__init__()
        self.model = copy.deepcopy(model)

    def forward(self, x, **kwargs):
        """Return ``(embedded_x, cls_tokens)``; ``kwargs`` are ignored."""
        x = self.model.patch_embed(x)
        x = self.model.pos_drop(x + self.model.pos_embed)
        # Repeat the class token along dim 0 so there is one per batch item.
        cls_tokens = self.model.cls_token.expand(x.shape[0], -1, -1)
        return (x, cls_tokens)
class Concat(nn.Module):
    """Turn a ``(feature_map, cls_tokens)`` pair into a single token sequence.

    The 4-D feature map is flattened over its last two dimensions, transposed
    so that the flattened positions come second, and the class tokens are
    prepended along dim 1.

    Args:
        model: Accepted only so existing ``Concat(model)`` call sites keep
            working. ``forward`` never reads it, so no copy of it is stored.
    """

    def __init__(self, model):
        super().__init__()

    def forward(self, x, **kwargs):
        """Return the concatenated sequence; ``kwargs`` are ignored.

        Raises:
            ValueError: If the feature map is not 4-dimensional.
        """
        x, cls_tokens = x
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D feature map, got {x.dim()} dimension(s)"
            )
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        x = x.flatten(2).transpose(1, 2)
        return torch.cat((cls_tokens, x), dim=1)
class Pool(nn.Module):
    """Split a token sequence into class tokens and a square feature map, then pool.

    Args:
        block: Callable invoked as ``block(feature_map, cls_tokens)`` and
            returning a pair. It is deep-copied.
        token_length: Number of leading tokens along dim 1 that are treated
            as class tokens.
    """

    def __init__(self, block, token_length):
        super().__init__()
        self.block = copy.deepcopy(block)
        self.token_length = token_length

    def forward(self, x, **kwargs):
        """Return the ``(feature_map, cls_tokens)`` pair produced by ``block``.

        ``kwargs`` are ignored.
        """
        cls_tokens = x[:, :self.token_length]
        x = x[:, self.token_length:]
        B, N, C = x.shape
        # The remaining N tokens are laid out as a square grid.
        side = int(math.sqrt(N))
        # (B, N, C) -> (B, C, N) -> (B, C, side, side)
        x = x.transpose(1, 2).reshape(B, C, side, side)
        x, cls_tokens = self.block(x, cls_tokens)
        return x, cls_tokens
class Classifier(nn.Module):
    """Classification stage: take index 0 along dim 1, normalize, apply the head.

    Args:
        norm: Normalization module; deep-copied.
        head: Output module; deep-copied.
    """

    def __init__(self, norm, head):
        super().__init__()
        self.head = copy.deepcopy(head)
        self.norm = copy.deepcopy(norm)

    def forward(self, x, **kwargs):
        """Return ``head(norm(x[:, 0]))``; ``kwargs`` are ignored."""
        x = x[:, 0]
        x = self.norm(x)
        x = self.head(x)
        return x
class Residual(nn.Module):
    """Skip-connection wrapper computing ``fn(x) + x``.

    Args:
        *fn: Modules applied in order; they are chained in one
            ``nn.Sequential`` stored as ``self.fn``.
    """

    def __init__(self, *fn):
        super().__init__()
        self.fn = nn.Sequential(*fn)

    def forward(self, x, **kwargs):
        """Return the output of the wrapped chain plus the unchanged input."""
        # NOTE(review): kwargs are forwarded to nn.Sequential, whose forward
        # takes only the input, so any non-empty kwargs raise TypeError.
        return self.fn(x, **kwargs) + x
def flatten(xs_list):
    """Return one flat list with the items of every inner iterable, in order.

    Only one level of nesting is removed.
    """
    return [x for xs in xs_list for x in xs]
# `blocks` is the PiT split into a sequence of stages: the embedding stage,
# then for each of the three entries of `model.transformers` an attention
# residual and an MLP residual per block, with a pooling stage between
# consecutive entries, and finally the classifier.
blocks = [
# PatchEmbed returns a (feature map, class tokens) pair; Concat merges it into one sequence.
nn.Sequential(PatchEmbed(model), Concat(model),),
*flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)]
for b in model.transformers[0].blocks]),
# Pool treats the first token (token_length=1) as the class token; Concat re-merges its output pair.
nn.Sequential(Pool(model.transformers[0].pool, 1), Concat(model),),
*flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)]
for b in model.transformers[1].blocks]),
nn.Sequential(Pool(model.transformers[1].pool, 1), Concat(model),),
*flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)]
for b in model.transformers[2].blocks]),
Classifier(model.norm, model.head),
]
# Swin-Ti
import copy
import timm
import models
import torch
import torch.nn as nn
# a. download and load a pretrained model for CIFAR-100
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/swin_ti_cifar100_ec2894492b.pth.tar"
path = "checkpoints/swin_ti_cifar100_ec2894492b.pth.tar"
# Fetch the checkpoint at `url` into the local file `path`.
models.download(url=url, path=path)
# Build the timm SwinTransformer; the first argument row holds the CIFAR-100
# input settings, the second row the Swin-Ti size settings.
model = timm.models.swin_transformer.SwinTransformer(
num_classes=100, img_size=32, patch_size=1, window_size=4, # for CIFAR-100
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), qkv_bias=False, # for Swin-Ti
)
model.name = "swin_ti"
# NOTE(review): `models.stats` is defined outside this snippet; presumably it
# reports statistics of the model — confirm.
models.stats(model)
# Map the checkpoint tensors to the GPU when CUDA is available, else the CPU.
map_location = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=map_location)
# The weights are stored under the "state_dict" key of the checkpoint.
model.load_state_dict(checkpoint["state_dict"])
# b. model → blocks. `blocks` is a sequence of blocks
class Attn(nn.Module):
    """Attention half of a Swin block, exposed as a standalone stage.

    A deep copy of ``block`` is taken and its ``mlp`` and ``norm2`` are
    replaced by ``nn.Identity``, so only the remaining part of the block
    acts on the input.

    Args:
        block: Module exposing ``mlp`` and ``norm2`` attributes. The
            caller's block is left untouched because only the copy is edited.
    """

    def __init__(self, block):
        super().__init__()
        self.block = copy.deepcopy(block)
        self.block.mlp = nn.Identity()
        self.block.norm2 = nn.Identity()

    def forward(self, x, **kwargs):
        """Return the output of the stripped block, halved; ``kwargs`` are ignored."""
        x = self.block(x)
        # NOTE(review): presumably the block still applies its second residual
        # `x + mlp(norm2(x))`, which doubles the value once both are Identity;
        # the division undoes that — confirm against the block's forward.
        x = x / 2
        return x
class MLP(nn.Module):
    """MLP half of a Swin block, exposed as a standalone stage.

    Computes ``x + mlp(norm2(x))`` with private copies of the block's
    ``mlp`` and ``norm2``.

    Args:
        block: Module exposing ``mlp`` and ``norm2`` attributes.
    """

    def __init__(self, block):
        super().__init__()
        # Copy only the two sub-modules that are kept; the rest of the block
        # is not needed here.
        self.mlp = copy.deepcopy(block.mlp)
        self.norm2 = copy.deepcopy(block.norm2)

    def forward(self, x, **kwargs):
        """Return the residual MLP output for ``x``; ``kwargs`` are ignored."""
        x = x + self.mlp(self.norm2(x))
        return x
class Classifier(nn.Module):
    """Classification stage: normalize, average over dim 1, apply the head.

    Args:
        norm: Normalization module; deep-copied.
        head: Output module; deep-copied.
    """

    def __init__(self, norm, head):
        super().__init__()
        self.norm = copy.deepcopy(norm)
        self.head = copy.deepcopy(head)

    def forward(self, x, **kwargs):
        """Return ``head(norm(x).mean(dim=1))``; ``kwargs`` are ignored."""
        x = self.norm(x)
        x = x.mean(dim=1)
        x = self.head(x)
        return x
def flatten(xs_list):
    """Return one flat list with the items of every inner iterable, in order.

    Only one level of nesting is removed.
    """
    return [x for xs in xs_list for x in xs]
# `blocks` is the Swin model split into a sequence of stages: the patch
# embedding, then for each of the four entries of `model.layers` an
# (Attn, MLP) pair per block, with that layer's `downsample` module after
# each of the first three layers, and finally the classifier.
blocks = [
model.patch_embed,
*flatten([[Attn(block), MLP(block)] for block in model.layers[0].blocks]),
model.layers[0].downsample,
*flatten([[Attn(block), MLP(block)] for block in model.layers[1].blocks]),
model.layers[1].downsample,
*flatten([[Attn(block), MLP(block)] for block in model.layers[2].blocks]),
model.layers[2].downsample,
*flatten([[Attn(block), MLP(block)] for block in model.layers[3].blocks]),
Classifier(model.norm, model.head)
]
Please feel free to reopen this issue if you have any problems.
Thanks very much !!!
Could you make the already-trained models from this work accessible? Thank you very much in advance.