soniajoseph / ViT-Prisma

ViT Prisma is a mechanistic interpretability library for Vision Transformers (ViTs).

Test that our pretrained Prisma models from the old repo fit into HookedViT #79

Open soniajoseph opened 8 months ago

soniajoseph commented 8 months ago

Try loading the ImageNet-1k models trained by Yash (see the repo documentation) into HookedViT.

If they don't load properly, adapt the model / create helper functions to load them in.

Also try with the dSprites models in the documentation.
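
For reference, the loading attempt would look something like this (a sketch; the exact from_pretrained signature for these checkpoints isn't confirmed here, and the repo id is the one from the Hugging Face discussion linked below):

from vit_prisma.models.base_vit import HookedViT

# Sketch: assumes from_pretrained accepts a Hugging Face repo id.
model = HookedViT.from_pretrained("IamYash/ImageNet-Tiny-AttentionOnly")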

YashVadi commented 7 months ago

The previously trained models don't currently load with HookedViT.from_pretrained. Issue: https://huggingface.co/IamYash/ImageNet-Tiny-AttentionOnly/discussions/2#660e9e9ef460ed4c9f617e71

soniajoseph commented 7 months ago

Add a config file to Hugging Face using this template: https://github.com/soniajoseph/ViT-Prisma/blob/main/src/vit_prisma/configs/HookedViTConfig.py
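
For the attention-only ImageNet-Tiny model in this thread, the config would look roughly like this (values taken from the conversion script below; a sketch rather than the canonical file):

from vit_prisma.configs.HookedViTConfig import HookedViTConfig

config = HookedViTConfig(
    n_layers=1,
    patch_size=16,
    d_model=768,
    d_head=192,
    d_mlp=3072,
    attn_only=True,
    n_classes=1000,
    return_type="class_logits",
    normalization_type=None,
)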

PraneetNeuro commented 7 months ago

Worked with @themachinefan on the script to convert weights of models trained with the legacy Prisma code. Attaching the code here for reference. The conversion script is not part of the repo because it won't be required anymore.

Models are being converted and uploaded to Hugging Face; they are available here.

import os

import einops
import torch

from vit_prisma.models.base_vit import HookedViT
from vit_prisma.configs.HookedViTConfig import HookedViTConfig

def convert_legacy_prisma_weights(
        old_state_dict,
        cfg: HookedViTConfig,
):
    """Map a legacy Prisma ViT state dict onto HookedViT's parameter names."""
    # Checkpoints saved with torch.nn.DataParallel prefix every key with
    # "module."; strip it so the lookups below work in both cases.
    old_state_dict = {k.replace("module.", ""): v for k, v in old_state_dict.items()}

    new_state_dict = {}
    new_state_dict["cls_token"] = old_state_dict["cls_token"]
    new_state_dict["pos_embed.W_pos"] = old_state_dict["position_embedding"].squeeze(0)
    new_state_dict["embed.proj.weight"] = old_state_dict["patch_embedding.proj.weight"]
    # The patch-embedding bias is not copied from the legacy checkpoint; fill with zeros.
    new_state_dict["embed.proj.bias"] = torch.zeros(cfg.d_model)

    # Keys look like "blocks.<layer>.<...>"; recover the number of blocks.
    max_block = max(int(key.split(".")[1]) for key in old_state_dict if "blocks" in key)

    for layer in range(max_block + 1):
        layer_key = f"blocks.{layer}"
        new_layer_key = f"blocks.{layer}"

        # Legacy attention stores Q, K, and V fused in a single qkv projection.
        qkv_weights = old_state_dict[f"{layer_key}.attention.qkv.weight"]
        proj_weights = old_state_dict[f"{layer_key}.attention.proj.weight"]
        proj_bias = old_state_dict[f"{layer_key}.attention.proj.bias"]

        # Split the fused projection into equal Q, K, V chunks along dim 0.
        split_size = qkv_weights.shape[0] // 3
        q_w = qkv_weights[:split_size, :]
        k_w = qkv_weights[split_size:2 * split_size, :]
        v_w = qkv_weights[2 * split_size:, :]
        # (n_heads * d_head, d_model) -> (n_heads, d_model, d_head),
        # the layout HookedViT uses for W_Q / W_K / W_V.
        q_w_reshaped = q_w.reshape(cfg.n_heads, cfg.d_head, -1).permute(0, 2, 1)
        k_w_reshaped = k_w.reshape(cfg.n_heads, cfg.d_head, -1).permute(0, 2, 1)
        v_w_reshaped = v_w.reshape(cfg.n_heads, cfg.d_head, -1).permute(0, 2, 1)
        new_state_dict[f"{new_layer_key}.attn.W_Q"] = q_w_reshaped
        new_state_dict[f"{new_layer_key}.attn.W_K"] = k_w_reshaped
        new_state_dict[f"{new_layer_key}.attn.W_V"] = v_w_reshaped

        new_state_dict[f"{new_layer_key}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head)
        new_state_dict[f"{new_layer_key}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head)
        new_state_dict[f"{new_layer_key}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head)

        # Output projection: (d_model, n_heads * d_head) -> (n_heads, d_head, d_model),
        # the layout HookedViT uses for W_O.
        proj_w_reshaped = einops.rearrange(
            proj_weights, "d (n h) -> n h d",
            n=cfg.n_heads, h=cfg.d_head, d=cfg.d_model,
        )
        new_state_dict[f"{new_layer_key}.attn.W_O"] = proj_w_reshaped
        new_state_dict[f"{new_layer_key}.attn.b_O"] = proj_bias

        if not cfg.attn_only:
            # HookedViT stores MLP weights as (d_in, d_out), hence the transposes.
            new_state_dict[f"{new_layer_key}.mlp.W_in"] = old_state_dict[f"{layer_key}.mlp.fc1.weight"].T
            new_state_dict[f"{new_layer_key}.mlp.W_out"] = old_state_dict[f"{layer_key}.mlp.fc2.weight"].T
            new_state_dict[f"{new_layer_key}.mlp.b_in"] = old_state_dict[f"{layer_key}.mlp.fc1.bias"]
            new_state_dict[f"{new_layer_key}.mlp.b_out"] = old_state_dict[f"{layer_key}.mlp.fc2.bias"]

    new_state_dict["ln_final.w"] = torch.ones(cfg.d_model)
    new_state_dict["ln_final.b"] = torch.zeros(cfg.d_model)

    h_W = old_state_dict["head.weight"] 
    h_W = einops.rearrange(h_W, "c d -> d c", d = cfg.d_model, c = cfg.n_classes)
    new_state_dict["head.W_H"] = h_W
    new_state_dict["head.b_H"] = old_state_dict["head.bias"]

    return new_state_dict

def convert_weights(checkpoint_path, folder_to_save):
    # Important: change this config to match the old model being converted.
    config = HookedViTConfig(
        n_layers=1,
        patch_size=16,
        d_model=768,
        attn_only=True,
        d_head=192,
        d_mlp=3072,
        n_classes=1000,
        return_type="class_logits",
        normalization_type=None,
    )

    model = HookedViT(config)

    # Some checkpoints nest the weights under 'model_state_dict'; uncomment
    # the indexing if needed.
    old_state_dict = torch.load(checkpoint_path, map_location="cpu")  # ['model_state_dict']

    new_state_dict = convert_legacy_prisma_weights(old_state_dict, config)

    # Load and apply the standard weight-processing passes.
    model.load_and_process_state_dict(
        new_state_dict,
        fold_ln=True,
        center_writing_weights=True,
        fold_value_biases=True,
        refactor_factored_attn_matrices=True,
    )

    # Save under the same file name in the destination folder.
    torch.save(model.state_dict(), os.path.join(folder_to_save, os.path.basename(checkpoint_path)))

# Batch-convert every checkpoint in old_checkpoints/.
for model in os.listdir("old_checkpoints"):
    try:
        convert_weights(os.path.join('old_checkpoints', model), 'converted_checkpoints')
    except Exception as e:
        print(e)
        print(f'Failed for {model}')
        continue
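
A converted checkpoint can then be sanity-checked by loading it back into a HookedViT built from the same config and running a dummy forward pass. A sketch, assuming a placeholder file name and a standard 224x224 ImageNet input:

import torch
from vit_prisma.models.base_vit import HookedViT
from vit_prisma.configs.HookedViTConfig import HookedViTConfig

# Same config as in convert_weights above.
config = HookedViTConfig(n_layers=1, patch_size=16, d_model=768, attn_only=True,
                         d_head=192, d_mlp=3072, n_classes=1000,
                         return_type="class_logits", normalization_type=None)
model = HookedViT(config)
# "converted_checkpoints/model.pth" is a placeholder path.
model.load_state_dict(torch.load("converted_checkpoints/model.pth", map_location="cpu"))
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expect (1, 1000) given return_type="class_logits"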