xxxnell / how-do-vits-work

(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"
https://arxiv.org/abs/2202.06709
Apache License 2.0
806 stars 79 forks source link

Trained models #8

Closed luluenen closed 2 years ago

luluenen commented 2 years ago

Could you make the already-trained models from this work available? Thank you very much in advance.

xxxnell commented 2 years ago

Hi,

I'm planning to release at least the ResNet and ViT models for CIFAR-100, and I'll comment on this issue once those pretrained models are released. For ImageNet-1K, I recommend using the pretrained models from timm (e.g., see fourier_analysis.ipynb).

xxxnell commented 2 years ago

I just released pretrained models for CIFAR-100 (ResNet-50, ViT-Ti, PiT-Ti, and Swin-Ti).

The code below contains snippets for (a) loading the pretrained models and (b) optionally converting them into block sequences. Please refer to featuremap_variance.ipynb (Colab notebook) for example code.

# ResNet-50
import models

# a. download and load a pretrained model for CIFAR-100
# `torch` is used below (torch.cuda / torch.load) but was never imported in this snippet.
import torch

url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/resnet_50_cifar100_691cc9a9e4.pth.tar"
path = "checkpoints/resnet_50_cifar100_691cc9a9e4.pth.tar"
models.download(url=url, path=path)

name = "resnet_50"
# The snippet previously read `stem=model_args.get("stem", False)`, but `model_args`
# is never defined here, so it raised NameError. Use that same fallback value
# explicitly; change it only if the checkpoint was trained with a different stem.
model = models.get_model(name, num_classes=100,  # timm does not provide a ResNet for CIFAR
                         stem=False)
map_location = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(path, map_location=map_location)
model.load_state_dict(checkpoint["state_dict"])

# b. model → blocks. `blocks` is a sequence of blocks
# layer0 (presumably the stem) stays whole; the four stages are unrolled into
# their individual sub-blocks; the classifier closes the sequence.
blocks = (
    [model.layer0]
    + list(model.layer1)
    + list(model.layer2)
    + list(model.layer3)
    + list(model.layer4)
    + [model.classifier]
)
# ViT-Ti
import copy
import timm
import torch
import torch.nn as nn
import models

# a. download and load a pretrained model for CIFAR-100
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/vit_ti_cifar100_9857b21357.pth.tar"
path = "checkpoints/vit_ti_cifar100_9857b21357.pth.tar"
models.download(url=url, path=path)

# The architecture must match the checkpoint exactly; the strict
# load_state_dict call below rejects missing or unexpected keys.
model = timm.models.vision_transformer.VisionTransformer(
    img_size=32, patch_size=2, num_classes=100,  # CIFAR-100 geometry
    embed_dim=192, depth=12, num_heads=3,        # ViT-Ti
    qkv_bias=False,
)
model.name = "vit_ti"
models.stats(model)

if torch.cuda.is_available():
    map_location = "cuda"
else:
    map_location = "cpu"
checkpoint = torch.load(path, map_location=map_location)
model.load_state_dict(checkpoint["state_dict"])

# b. model → blocks. `blocks` is a sequence of blocks

class PatchEmbed(nn.Module):
    """ViT stem: embed patches, prepend the class token, add positional embeddings.

    Holds a private deep copy of the whole model and uses only its
    ``patch_embed``, ``cls_token``, ``pos_embed`` and ``pos_drop`` members.
    The class token ends up at sequence position 0.
    """

    def __init__(self, model):
        super().__init__()
        self.model = copy.deepcopy(model)

    def forward(self, x, **kwargs):
        vit = self.model
        patches = vit.patch_embed(x)
        # One copy of the learned class token per sample in the batch.
        cls = vit.cls_token.expand(patches.shape[0], -1, -1)
        tokens = torch.cat([cls, patches], dim=1)
        return vit.pos_drop(tokens + vit.pos_embed)

class Residual(nn.Module):
    """Skip connection around a chain of modules: returns ``fn(x) + x``.

    ``nn.Sequential.forward`` takes a single input, so passing any keyword
    arguments through ``**kwargs`` will raise; call this with ``x`` only.
    """

    def __init__(self, *fn):
        super().__init__()
        self.fn = nn.Sequential(*fn)

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x

class Lambda(nn.Module):
    """Adapt a plain callable into an ``nn.Module`` (e.g. to place it in ``nn.Sequential``)."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        transform = self.fn
        return transform(x)

def flatten(xs_list):
    """Concatenate an iterable of iterables into one flat list (one level deep)."""
    flat = []
    for group in xs_list:
        flat.extend(group)
    return flat

# model → blocks. `blocks` is a sequence of blocks
# Stem first; then each transformer block split into its attention and MLP
# residual branches; finally norm → class token (position 0) → head.
blocks = [PatchEmbed(model)]
blocks.extend(
    branch
    for b in model.blocks
    for branch in (Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp))
)
blocks.append(nn.Sequential(model.norm, Lambda(lambda x: x[:, 0]), model.head))
# PiT-Ti
import copy
import math
import timm

import torch
import torch.nn as nn

# a. download and load a pretrained model for CIFAR-100
# `models` (download/stats helpers) is used below but, unlike the other
# snippets, was never imported here -> NameError without this line.
import models

url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/pit_ti_cifar100_0645889efb.pth.tar"
path = "checkpoints/pit_ti_cifar100_0645889efb.pth.tar"
models.download(url=url, path=path)

model = timm.models.pit.PoolingVisionTransformer(
    num_classes=100, img_size=32, patch_size=2, stride=1,  # for CIFAR-100
    base_dims=[32, 32, 32], depth=[2, 6, 4], heads=[2, 4, 8], mlp_ratio=4,  # for PiT-Ti
)
model.name = "pit_ti"
models.stats(model)
map_location = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(path, map_location=map_location)
model.load_state_dict(checkpoint["state_dict"])

# b. model → blocks. `blocks` is a sequence of blocks

class PatchEmbed(nn.Module):
    """PiT stem: patch embedding plus positional embedding, with the class token(s).

    Holds a private deep copy of the whole model. Returns a
    ``(feature_map, cls_tokens)`` tuple rather than a single tensor; ``Concat``
    turns that tuple into one token sequence.
    """

    def __init__(self, model):
        super().__init__()
        self.model = copy.deepcopy(model)

    def forward(self, x, **kwargs):
        pit = self.model
        fmap = pit.pos_drop(pit.patch_embed(x) + pit.pos_embed)
        # One copy of the learned class token(s) per sample in the batch.
        cls = pit.cls_token.expand(fmap.shape[0], -1, -1)
        return (fmap, cls)

class Concat(nn.Module):
    """Flatten a PiT feature map into tokens and prepend the class token(s).

    Input is a ``(x, cls_tokens)`` tuple where ``x`` is a rank-4
    ``(B, C, H, W)`` feature map. Output is one token sequence of shape
    ``(B, T + H*W, C)``, with the ``T`` class tokens first.
    """

    def __init__(self, model=None):
        super().__init__()
        # `model` is accepted only for backward compatibility with existing
        # `Concat(model)` calls. The previous version deep-copied it into a
        # registered submodule that forward() never used. Every instance
        # therefore carried a full duplicate of the network's parameters.

    def forward(self, x, **kwargs):
        x, cls_tokens = x
        # (B, C, H, W) -> (B, H*W, C)
        tokens = x.flatten(2).transpose(1, 2)
        return torch.cat((cls_tokens, tokens), dim=1)

class Pool(nn.Module):
    """Split off the class tokens, fold the rest back into a square map, then pool.

    ``token_length`` is the number of leading class tokens. The remaining
    ``N`` tokens are reshaped to ``(B, C, sqrt(N), sqrt(N))``, so a square
    feature map is assumed. The wrapped ``block`` is called as
    ``block(feature_map, cls_tokens)`` and must return a pair.
    """

    def __init__(self, block, token_length):
        super().__init__()
        self.block = copy.deepcopy(block)
        self.token_length = token_length

    def forward(self, x, **kwargs):
        n_cls = self.token_length
        cls_tokens, patches = x[:, :n_cls], x[:, n_cls:]
        batch, num_patches, channels = patches.shape
        side = int(math.sqrt(num_patches))
        fmap = patches.transpose(1, 2).reshape(batch, channels, side, side)

        pooled, cls_tokens = self.block(fmap, cls_tokens)
        return pooled, cls_tokens

class Classifier(nn.Module):
    """PiT head: take the first token of the sequence, normalize it, then classify."""

    def __init__(self, norm, head):
        super().__init__()
        self.head = copy.deepcopy(head)
        self.norm = copy.deepcopy(norm)

    def forward(self, x, **kwargs):
        first_token = x[:, 0]
        return self.head(self.norm(first_token))

class Residual(nn.Module):
    """Wrap a module chain in a skip connection: output is ``fn(x) + x``.

    Keyword arguments are forwarded to ``nn.Sequential``, whose forward takes
    only one input, so any non-empty ``**kwargs`` will raise.
    """

    def __init__(self, *fn):
        super().__init__()
        self.fn = nn.Sequential(*fn)

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        out = out + x
        return out

def flatten(xs_list):
    """Return the elements of every inner iterable of ``xs_list`` as one flat list."""
    result = []
    for inner in xs_list:
        for item in inner:
            result.append(item)
    return result

# Layout: stem → stage 0 → pool → stage 1 → pool → stage 2 → head.
# Each transformer block is split into its attention and MLP residual branches.
# Pool(..., 1) treats exactly one leading token as the class token.
blocks = [
    nn.Sequential(PatchEmbed(model), Concat(model)),
    *flatten((Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp))
             for b in model.transformers[0].blocks),
    nn.Sequential(Pool(model.transformers[0].pool, 1), Concat(model)),
    *flatten((Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp))
             for b in model.transformers[1].blocks),
    nn.Sequential(Pool(model.transformers[1].pool, 1), Concat(model)),
    *flatten((Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp))
             for b in model.transformers[2].blocks),
    Classifier(model.norm, model.head),
]
# Swin-Ti
import copy
import timm
import models

import torch
import torch.nn as nn

# a. download and load a pretrained model for CIFAR-100
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/swin_ti_cifar100_ec2894492b.pth.tar"
path = "checkpoints/swin_ti_cifar100_ec2894492b.pth.tar"
models.download(url=url, path=path)

# The architecture must match the checkpoint exactly; the strict
# load_state_dict call below rejects missing or unexpected keys.
model = timm.models.swin_transformer.SwinTransformer(
    img_size=32, patch_size=1, window_size=4, num_classes=100,   # CIFAR-100 geometry
    embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),  # Swin-Ti
    qkv_bias=False,
)
model.name = "swin_ti"
models.stats(model)

if torch.cuda.is_available():
    map_location = "cuda"
else:
    map_location = "cpu"
checkpoint = torch.load(path, map_location=map_location)
model.load_state_dict(checkpoint["state_dict"])

# b. model → blocks. `blocks` is a sequence of blocks

class Attn(nn.Module):
    """Attention half of a Swin block, run through a private copy of the block.

    The copy's ``mlp`` and ``norm2`` are replaced with identities, and the
    result is halved. Presumably the block's forward ends with
    ``x = x + mlp(norm2(x))``, as timm's SwinTransformerBlock does. With
    identities that step becomes ``x + x``, and halving undoes it, leaving only
    the attention residual. TODO confirm against the installed timm version.
    """

    def __init__(self, block):
        super().__init__()
        self.block = copy.deepcopy(block)
        # Neutralize the MLP half so only the attention path remains.
        self.block.mlp = nn.Identity()
        self.block.norm2 = nn.Identity()

    def forward(self, x, **kwargs):
        doubled = self.block(x)
        return doubled / 2

class MLP(nn.Module):
    """MLP half of a Swin block: ``x + mlp(norm2(x))``, using copies of the block's layers."""

    def __init__(self, block):
        super().__init__()
        source = copy.deepcopy(block)
        self.mlp, self.norm2 = source.mlp, source.norm2

    def forward(self, x, **kwargs):
        return x + self.mlp(self.norm2(x))

class Classifier(nn.Module):
    """Swin head: normalize, average over dim 1, then classify.

    Assumes ``x`` is a token sequence shaped ``(B, L, C)``, so the mean is
    global average pooling over tokens. TODO confirm for the timm version in use.
    """

    def __init__(self, norm, head):
        super().__init__()
        self.norm = copy.deepcopy(norm)
        self.head = copy.deepcopy(head)

    def forward(self, x, **kwargs):
        pooled = self.norm(x).mean(dim=1)
        return self.head(pooled)

def flatten(xs_list):
    """Flatten one level of nesting: ``[[a, b], [c]] -> [a, b, c]``."""
    out = []
    for chunk in xs_list:
        out.extend(chunk)
    return out

# Layout: patch embedding, then each of the four stages (every Swin block split
# into its attention and MLP halves), with each stage's downsample between
# stages, then the head.
blocks = [
    model.patch_embed,
    *flatten((Attn(blk), MLP(blk)) for blk in model.layers[0].blocks),
    model.layers[0].downsample,
    *flatten((Attn(blk), MLP(blk)) for blk in model.layers[1].blocks),
    model.layers[1].downsample,
    *flatten((Attn(blk), MLP(blk)) for blk in model.layers[2].blocks),
    model.layers[2].downsample,
    *flatten((Attn(blk), MLP(blk)) for blk in model.layers[3].blocks),
    Classifier(model.norm, model.head),
]

Please feel free to reopen this issue if you have any problems.

luluenen commented 2 years ago

Thanks very much!!!