soniajoseph opened this issue 8 months ago
Previously trained models don't load properly with HookedViT.from_pretrained.
issue: https://huggingface.co/IamYash/ImageNet-Tiny-AttentionOnly/discussions/2#660e9e9ef460ed4c9f617e71
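For reference, a minimal reproduction sketch, assuming `from_pretrained` accepts the Hugging Face repo id from the linked discussion directly:

```python
from vit_prisma.models.base_vit import HookedViT

# Attempt to load the legacy checkpoint from the Hugging Face Hub
# (repo id taken from the discussion linked above); with the current code
# the weights do not map correctly onto HookedViT's parameter names.
model = HookedViT.from_pretrained("IamYash/ImageNet-Tiny-AttentionOnly")
```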
Add a config file to Hugging Face using this template: https://github.com/soniajoseph/ViT-Prisma/blob/main/src/vit_prisma/configs/HookedViTConfig.py
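As a rough sketch (assuming `HookedViTConfig` is a dataclass, as in the linked template), the config could be serialized to JSON and added to the model repo; the field values below mirror the attention-only ImageNet model used later in this issue and are only illustrative:

```python
import json
from dataclasses import asdict

from vit_prisma.configs.HookedViTConfig import HookedViTConfig

# Illustrative values for the 1-layer attention-only ImageNet model;
# adjust to match whichever checkpoint is being uploaded.
cfg = HookedViTConfig(
    n_layers=1,
    patch_size=16,
    d_model=768,
    d_head=192,
    attn_only=True,
    n_classes=1000,
    return_type="class_logits",
    normalization_type=None,
)

# Write a config.json that can be uploaded alongside the weights.
with open("config.json", "w") as f:
    json.dump(asdict(cfg), f, indent=2, default=str)
```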
Worked with @themachinefan on a script to convert the weights of models trained with the legacy Prisma code. Attaching the code here for reference. The conversion script is not part of the repo because it won't be needed anymore.
Models are being converted and uploaded to Hugging Face; they are available here.
```python
import os

import torch
import einops

from vit_prisma.models.base_vit import HookedViT
from vit_prisma.configs.HookedViTConfig import HookedViTConfig


def convert_legacy_prisma_weights(
    old_state_dict,
    cfg: HookedViTConfig,
):
    """Map a legacy Prisma ViT state dict onto HookedViT's parameter names."""
    new_state_dict = {}

    new_state_dict["cls_token"] = old_state_dict["cls_token"]
    new_state_dict["pos_embed.W_pos"] = old_state_dict["position_embedding"].squeeze(0)
    new_state_dict["embed.proj.weight"] = old_state_dict["patch_embedding.proj.weight"]
    # The patch-embedding bias is not converted; fill with zeros.
    new_state_dict["embed.proj.bias"] = torch.zeros(cfg.d_model)

    # Checkpoints saved with a "module." prefix (e.g. from DataParallel) have the
    # block index one position further along in the key.
    if "module." in list(old_state_dict.keys())[0]:
        block_index = 2
    else:
        block_index = 1
    max_block = max(
        int(key.split(".")[block_index]) for key in old_state_dict.keys() if "blocks" in key
    )

    for layer in range(max_block + 1):
        layer_key = f"blocks.{layer}"
        new_layer_key = f"blocks.{layer}"

        qkv_weights = old_state_dict[f"{layer_key}.attention.qkv.weight"]
        proj_weights = old_state_dict[f"{layer_key}.attention.proj.weight"]
        proj_bias = old_state_dict[f"{layer_key}.attention.proj.bias"]

        # The legacy attention stores Q, K and V as one fused qkv matrix; split it into thirds.
        split_size = qkv_weights.shape[0] // 3
        q_w = qkv_weights[:split_size, :]
        k_w = qkv_weights[split_size:2 * split_size, :]
        v_w = qkv_weights[2 * split_size:, :]

        # Reshape (n_heads * d_head, d_model) -> (n_heads, d_model, d_head) as HookedViT expects.
        q_w_reshaped = q_w.reshape(cfg.n_heads, cfg.d_head, -1).permute(0, 2, 1)
        k_w_reshaped = k_w.reshape(cfg.n_heads, cfg.d_head, -1).permute(0, 2, 1)
        v_w_reshaped = v_w.reshape(cfg.n_heads, cfg.d_head, -1).permute(0, 2, 1)

        new_state_dict[f"{new_layer_key}.attn.W_Q"] = q_w_reshaped
        new_state_dict[f"{new_layer_key}.attn.W_K"] = k_w_reshaped
        new_state_dict[f"{new_layer_key}.attn.W_V"] = v_w_reshaped

        # No separate Q/K/V biases are converted; fill with zeros.
        new_state_dict[f"{new_layer_key}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head)
        new_state_dict[f"{new_layer_key}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head)
        new_state_dict[f"{new_layer_key}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head)

        # Output projection: (d_model, n_heads * d_head) -> (n_heads, d_head, d_model).
        proj_w_reshaped = einops.rearrange(
            proj_weights, "d (n h) -> n h d", n=cfg.n_heads, h=cfg.d_head, d=cfg.d_model
        )
        new_state_dict[f"{new_layer_key}.attn.W_O"] = proj_w_reshaped
        new_state_dict[f"{new_layer_key}.attn.b_O"] = proj_bias

        if not cfg.attn_only:
            # MLP weights are transposed relative to nn.Linear's storage layout.
            mlp_W_in = old_state_dict[f"{layer_key}.mlp.fc1.weight"].T
            new_state_dict[f"{new_layer_key}.mlp.W_in"] = mlp_W_in
            mlp_W_out = old_state_dict[f"{layer_key}.mlp.fc2.weight"].T
            new_state_dict[f"{new_layer_key}.mlp.W_out"] = mlp_W_out
            new_state_dict[f"{new_layer_key}.mlp.b_in"] = old_state_dict[f"{layer_key}.mlp.fc1.bias"]
            new_state_dict[f"{new_layer_key}.mlp.b_out"] = old_state_dict[f"{layer_key}.mlp.fc2.bias"]

    # Identity parameters for the final layer norm.
    new_state_dict["ln_final.w"] = torch.ones(cfg.d_model)
    new_state_dict["ln_final.b"] = torch.zeros(cfg.d_model)

    # Classification head: transpose (n_classes, d_model) -> (d_model, n_classes).
    h_W = old_state_dict["head.weight"]
    h_W = einops.rearrange(h_W, "c d -> d c", d=cfg.d_model, c=cfg.n_classes)
    new_state_dict["head.W_H"] = h_W
    new_state_dict["head.b_H"] = old_state_dict["head.bias"]

    return new_state_dict


def convert_weights(checkpoint_path, folder_to_save):
    # Important: change the config to match the old model.
    config = HookedViTConfig(
        n_layers=1,
        patch_size=16,
        d_model=768,
        attn_only=True,
        d_head=192,
        d_mlp=3072,
        n_classes=1000,
        return_type="class_logits",
        normalization_type=None,
    )
    model = HookedViT(config)

    old_state_dict = torch.load(checkpoint_path, map_location="cpu")  # ['model_state_dict']
    # old_state_dict = {k.replace("module.", ""): v for k, v in old_state_dict.items()}

    new_state_dict = convert_legacy_prisma_weights(old_state_dict, config)
    model.load_and_process_state_dict(
        new_state_dict,
        fold_ln=True,
        center_writing_weights=True,
        fold_value_biases=True,
        refactor_factored_attn_matrices=True,
    )
    torch.save(model.state_dict(), f'{folder_to_save}/{checkpoint_path.split("/")[-1]}')


# Convert every checkpoint in old_checkpoints/ and save the results to converted_checkpoints/.
for model_file in os.listdir("old_checkpoints"):
    try:
        convert_weights(os.path.join("old_checkpoints", model_file), "converted_checkpoints")
    except Exception as e:
        print(e)
        print(f"Failed for {model_file}")
        continue
```
Try loading the ImageNet-1k models trained by Yash (see the repo documentation) into HookedViT.
If they don't load properly, adapt the model or create helper functions to load them; a sanity-check sketch follows below.
Also try the dSprites models from the documentation.
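A minimal sanity-check sketch for a converted checkpoint (the file name and the 224x224 input size are assumptions; match them to the actual checkpoint):

```python
import torch

from vit_prisma.models.base_vit import HookedViT
from vit_prisma.configs.HookedViTConfig import HookedViTConfig

# The config must match the converted model (same values as in convert_weights above).
cfg = HookedViTConfig(
    n_layers=1, patch_size=16, d_model=768, attn_only=True, d_head=192,
    d_mlp=3072, n_classes=1000, return_type="class_logits", normalization_type=None,
)
model = HookedViT(cfg)

# Hypothetical path to a checkpoint produced by the conversion script.
state_dict = torch.load("converted_checkpoints/example.pth", map_location="cpu")
model.load_state_dict(state_dict)

# Forward a dummy batch; assumes 224x224 RGB ImageNet-style inputs.
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expect (1, 1000) class logits
```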