Request to add registers and position embedding interpolation #2

swarajnanda2021 commented 5 months ago

Lol, I had fun with this code, but it wasn't well suited to my use case without position-embedding interpolation and register tokens, so I added both. Take a look and see whether you'd like to include them in your code:

import math
from functools import partial
from typing import Callable

import torch
import torch.jit
import torch.nn as nn
import torch.utils.checkpoint
from timm.layers import Mlp, PatchDropout, trunc_normal_
from timm.models._manipulate import checkpoint_seq, named_apply
from timm.models.vision_transformer import (Block, _load_weights,

from soft_moe.soft_moe import SoftMoELayerWrapper

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A single strided convolution (kernel size == stride == ``patch_size``)
    projects every patch to an ``embed_dim``-dimensional vector, so an image
    containing ``n`` patches yields an ``n x embed_dim`` matrix.

    Args:
        img_size: Side length of the (square) input image in pixels.
        patch_size: Side length of each square patch in pixels.
        in_channels: Number of input image channels.
        embed_dim: Size of the embedding produced for each patch.
    """

    def __init__(self, img_size, patch_size, in_channels=3, embed_dim=256):
        # nn.Module must be initialised before any submodule is assigned,
        # otherwise registering ``self.project`` raises AttributeError.
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        # Patch count for a square image of side ``img_size``.
        self.n_patches = (img_size // patch_size) ** 2
        self.project = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, n_patches, embed_dim).
        """
        x = self.project(x)     # batch x embed_dim x grid_rows x grid_cols
        x = x.flatten(2)        # batch x embed_dim x n_patches
        x = x.transpose(1, 2)   # batch x n_patches x embed_dim
        return x

class SoftMoEVisionTransformer(nn.Module):
    """Vision Transformer with Soft Mixture of Experts MLP layers.

    From the paper "From Sparse to Soft Mixtures of Experts"

    Code modified from:

    def __init__(
        num_experts: int = 128,
        slots_per_expert: int = 1,
        moe_layer_index: int | list[int] = 6,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int] = 16,
        in_chans: int = 3,
        global_pool: str = "token",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: float | None = None,
        class_token: bool = True,
        no_embed_class: bool = False,
        pre_norm: bool = False,
        fc_norm: bool | None = None,
        drop_rate: float = 0.0,
        pos_drop_rate: float = 0.0,
        patch_drop_rate: float = 0.0,
        proj_drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        weight_init: str = "",
        embed_layer: Callable = PatchEmbed,
        norm_layer: Callable | None = None,
        act_layer: Callable | None = None,
        block_fn: Callable = Block,
        mlp_layer: Callable = Mlp,
            num_experts (int): Number of experts in MoE layers.
            slots_per_expert (int): Number of token slots per expert.
            moe_layer_index (int or list[int]): Block depth indices where MoE layers are used.
                Either an int which denotes where MoE layers are used from to the end, or a list
                of ints denoting the specific blocks (both use 0-indexing).
            img_size (int or tuple[int, int]): Input image size.
            patch_size (int or tuple[int, int]): Patch size.
            in_chans (int): Number of image input channels.
            global_pool (str): Type of global pooling for the final sequence (default: 'token').
            embed_dim (int): Transformer embedding dimension.
            depth (int): Depth of the transformer.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
            qkv_bias (bool): Enable bias for qkv projections if True.
            qk_norm (bool): Enable normalization of query and key in self-attention.
            init_values (float or None): Layer-scale init values (layer-scale enabled if not None).
            class_token (bool): Use a class token.
            no_embed_class (bool): Do not embed class tokens in the patch embedding.
            pre_norm (bool): Apply normalization before self-attention in the transformer block.
            fc_norm (bool or None): Pre-head norm after pool (instead of before).
                If None, enabled when global_pool == 'avg'.
            drop_rate (float): Head dropout rate.
            pos_drop_rate (float): Position embedding dropout rate.
            attn_drop_rate (float): Attention dropout rate.
            drop_path_rate (float): Stochastic depth rate.
            weight_init (str): Weight initialization scheme.
            embed_layer (Callable): Patch embedding layer.
            norm_layer (Callable or None): Normalization layer.
            act_layer (Callable or None): MLP activation layer.
            block_fn (Callable): Transformer block layer.
            mlp_layer (Callable): MLP layer.
        assert global_pool in ("", "avg", "token")
        assert class_token or global_pool != "token"
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.global_pool = global_pool
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.num_prefix_tokens = 1 if class_token else 0 
        self.no_embed_class = no_embed_class
        self.grad_checkpointing = False

        self.patch_embed = embed_layer(
        self.patch_embed.project.bias = None

        num_patches = (img_size//patch_size)**2

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None

        self.numregisters       = 4
        self.registers          = (

        embed_len = (
            num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        # Wrap the mlp_layer in a soft-moe wrapper
        self.num_experts = num_experts
        self.slots_per_expert = slots_per_expert

        moe_mlp_layer = partial(

        # Create a list where each index is the mlp layer class to
        # use at that depth
        self.moe_layer_index = moe_layer_index
        if isinstance(moe_layer_index, list):
            # Only the specified layers in moe_layer_index
            assert len(moe_layer_index) > 0
            assert all([0 <= l < depth for l in moe_layer_index])

            mlp_layers_list = [
                moe_mlp_layer if i in moe_layer_index else mlp_layer
                for i in range(depth)
            if moe_layer_index < depth: 
                # All layers including and after moe_layer_index

                mlp_layers_list = [
                    moe_mlp_layer if i >= moe_layer_index else mlp_layer
                    for i in range(depth)
            else: # hack to make all layers mlp
                mlp_layers_list = [
                    for i in range(depth)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.Sequential(
                for i in range(depth)
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # Classifier Head
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()

        if weight_init != "skip":

    def init_weights(self, mode=""):
        """Initialise the position embedding and the optional learned tokens.

        Args:
            mode: Name of the initialisation scheme; must be one of
                ``"jax"``, ``"jax_nlhb"``, ``"moco"`` or ``""``. In the code
                visible here it is only validated, not otherwise used.
        """
        assert mode in ("jax", "jax_nlhb", "moco", "")
        trunc_normal_(self.pos_embed, std=0.02)
        # Class token first, then registers; either may be absent (None).
        for learned_token in (self.cls_token, self.registers):
            if learned_token is not None:
                nn.init.normal_(learned_token, std=1e-6)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        # NOTE(review): no statement follows the comment above in this paste, so
        # the method body appears to have been lost and the definition is not
        # valid Python as shown. Restore the body from the author's original.

    def load_pretrained(self, checkpoint_path, prefix=""):
        """Load weights into this model by delegating to ``_load_weights``.

        Args:
            checkpoint_path: Forwarded unchanged to ``_load_weights``.
            prefix: Forwarded unchanged to ``_load_weights``; presumably a
                parameter-name prefix inside the checkpoint (confirm against
                the timm helper).
        """
        _load_weights(self, checkpoint_path, prefix)

    def no_weight_decay(self):
        """Return the set of parameter names listed as exempt from weight decay.

        NOTE(review): ``"dist_token"`` is listed although no attribute of that
        name is assigned in the code visible here; presumably kept for
        compatibility with the upstream implementation.
        """
        return {"pos_embed", "cls_token", "dist_token", "registers"}

    def group_matcher(self, coarse=False):
        """Return regex patterns that group parameters by model stage.

        Args:
            coarse: Accepted for interface compatibility; not used here.

        Returns:
            dict with a ``stem`` pattern (class token, position embedding and
            patch embedding) and a ``blocks`` list of ``(pattern, group)``
            pairs for the transformer blocks and the final norm.
        """
        return dict(
            stem=r"^cls_token|pos_embed|patch_embed",  # stem and embed
            blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
        )

    def set_grad_checkpointing(self, enable=True):
        """Set the ``grad_checkpointing`` flag that ``forward_features`` reads.

        Args:
            enable: New value for ``self.grad_checkpointing``.
        """
        self.grad_checkpointing = enable

    def pos_embedding_interp(self, x, h, w):
        """Return position embeddings sized to match the tokens in ``x``.

        ``self.pos_embed`` is assumed to hold one leading class-token entry
        followed by a square grid of patch entries. When the input yields a
        different patch grid, the patch entries are resampled to the new grid
        and the class-token entry is passed through unchanged.

        Args:
            x: Token tensor of shape (batch, 1 + num_patches, dim), i.e. the
                patch embeddings with the class token already prepended.
            h: Image extent (pixels) that determines the second axis of the
                interpolated patch grid.
            w: Image extent (pixels) that determines the first axis of the
                interpolated patch grid.

        Returns:
            Tensor of shape (1, 1 + num_patches, dim) that can be added to
            ``x`` by broadcasting over the batch dimension.
        """
        num_patches = x.shape[1] - 1     # one token is the class token
        N = self.pos_embed.shape[1] - 1  # patch positions stored in pos_embed

        # Fast path: same patch count on a square input, so the stored table
        # already fits. Return the tensor itself (not its length) so the caller
        # can add it to ``x``.
        if num_patches == N and h == w:
            return self.pos_embed

        # Separate the class-token entry from the patch-grid entries.
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]  # embedding dimensionality
        grid = int(math.sqrt(N))
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # Nudge the targets upward so that floor(grid * scale_factor) cannot
        # land one below the intended size because of floating-point rounding.
        w0, h0 = w0 + 0.1, h0 + 0.1

        # NOTE(review): the interpolation mode was lost from the pasted code;
        # "bicubic" is assumed because the routine mirrors DINO's
        # interpolate_pos_encoding. Confirm against the author's original.
        patch_pos_embed = torch.nn.functional.interpolate(
            patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def _pos_embed(self, x):
        # Runs the patch embedding on the input image, adds the position
        # embeddings returned by pos_embedding_interp, and applies pos_drop.
        # NOTE(review): two statements below (the line starting `x =, -1, -1)`
        # and the `x =` / `dim = 1,` pair) are truncated in this paste and are
        # not valid Python as shown. Presumably the first prepends the class
        # token and the second concatenates the register tokens; restore both
        # from the author's original before use.

        # original timm, JAX, and deit vit impl
        # pos_embed has entry for class token, concat then add
        # NOTE(review): dim 2 is bound to W and dim 3 to H here, whereas
        # PatchEmbed.forward documents its input as (B, C, H, W); confirm the
        # intended axis naming.
        batches, _, W, H = x.shape # B, C, W, H
        x = self.patch_embed(x)
        x =, -1, -1), x), dim=1)

        # Position embeddings are added before the registers are attached.
        x = x + self.pos_embedding_interp(x,H,W) # I changed this else registers does not work

        if self.registers is not None:
            x =
                dim = 1,

        return self.pos_drop(x)

    def forward_features(self, x):
        """Run the encoder and return the normalised token sequence.

        Applies, in order: ``_pos_embed``, ``patch_drop``, ``norm_pre``, the
        transformer blocks and ``norm``.

        Args:
            x: Input passed to ``_pos_embed``.

        Returns:
            Output of ``self.norm`` applied to the block outputs.
        """
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        # Run the blocks exactly once: through checkpoint_seq when gradient
        # checkpointing is enabled (and not scripting), directly otherwise.
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        """Encode ``x`` and return the token at sequence index 0.

        Only the first token of each sequence produced by
        ``forward_features`` is returned; the original author uses it as the
        class token.
        """
        features = self.forward_features(x)
        first_token = features[:, 0]
        return first_token
swarajnanda2021 commented 5 months ago

You'll also have to reimplement PatchEmbed, or timm throws a fit, so I added mine. I've kept the number of registers fixed (4 in the code above).

The above code should help methods like Dino work with this encoder.