facebookresearch / jepa

PyTorch code and models for V-JEPA self-supervised learning from video.

Load the pretrained model in Kaggle to interact directly with it #49

Open peternasser99 opened 7 months ago

peternasser99 commented 7 months ago

I am trying to load the pretrained ImageNet-1K model in Kaggle to interact with it, but the performance I am getting is random at best. Any help is much appreciated. Dataset required in Kaggle: imagenet-1k-resized-256.

I copied the relevant pieces from the eval script; the code is as follows. It takes about a minute to run, mostly for the download.

"""

Get the repo in cell 1:

```python
!git clone https://github.com/facebookresearch/jepa.git
import os
os.chdir('/kaggle/working/jepa')
!pip install .
```

```python
!wget https://dl.fbaipublicfiles.com/jepa/vitl16/in1k-probe.pth.tar
!wget https://dl.fbaipublicfiles.com/jepa/vitl16/vitl16.pth.tar
```
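Before wiring these into the model, it can help to confirm what the checkpoints actually contain. A small sanity check of my own (not from the original post); the key names in the comments are the ones the loading code below relies on:

```python
# Inspect the top-level keys of the downloaded checkpoints.
import torch

ckpt = torch.load('vitl16.pth.tar', map_location='cpu')
print(list(ckpt.keys()))    # the loader below expects 'target_encoder' (or 'encoder') and 'epoch'

probe = torch.load('in1k-probe.pth.tar', map_location='cpu')
print(list(probe.keys()))   # the classifier loader below expects a 'classifier' key
```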

Config for the model I want:

```python
import yaml

with open('/kaggle/working/jepa/configs/evals/vitl16_in1k.yaml', 'r') as y_file:
    params = yaml.load(y_file, Loader=yaml.FullLoader)

params['pretrain']['folder'] = '/kaggle/working/jepa'
params['pretrain']['checkpoint'] = 'vitl16.pth.tar'
```
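For orientation, the fields the rest of this script reads from the config live under `params['pretrain']` (model_name, checkpoint_key, patch_size, folder, checkpoint, frames_per_clip, tubelet_size, ...) and `params['data']` (resolution, num_classes). A quick way to eyeball them (my addition, not from the original post):

```python
# Print the config sections the rest of the script depends on; the key names
# listed above come from the .get(...) calls further down, not from the YAML itself.
print(params['pretrain'])
print(params['data'])
```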

Loading the model:

```python
import jepa.src.models.vision_transformer as vit
import torch

def load_pretrained(encoder, pretrained, checkpoint_key='target_encoder'):
    print(f'Loading pretrained model from {pretrained}')
    checkpoint = torch.load(pretrained, map_location='cpu')
    try:
        pretrained_dict = checkpoint[checkpoint_key]
    except Exception:
        pretrained_dict = checkpoint['encoder']

    # Strip wrapper prefixes from the checkpoint keys
    pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}
    pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()}
    for k, v in encoder.state_dict().items():
        if k not in pretrained_dict:
            print(f'key "{k}" could not be found in loaded state dict')
        elif pretrained_dict[k].shape != v.shape:
            # NOTE: a shape-mismatched key silently keeps the encoder's random init
            print(f'key "{k}" is of different shape in model and loaded state dict')
            pretrained_dict[k] = v
    msg = encoder.load_state_dict(pretrained_dict, strict=False)
    print(f'loaded pretrained model with msg: {msg}')
    print(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}')
    del checkpoint
    return encoder

def init_model(
    device,
    pretrained,
    model_name,
    patch_size=16,
    crop_size=224,
    # Video-specific parameters
    frames_per_clip=16,
    tubelet_size=2,
    use_sdpa=False,
    use_SiLU=False,
    tight_SiLU=True,
    uniform_power=False,
    checkpoint_key='target_encoder'
):
    encoder = vit.__dict__[model_name](
        img_size=crop_size,
        patch_size=patch_size,
        num_frames=frames_per_clip,
        tubelet_size=tubelet_size,
        uniform_power=uniform_power,
        use_sdpa=use_sdpa,
        use_SiLU=use_SiLU,
        tight_SiLU=tight_SiLU,
    )
    if frames_per_clip > 1:
        def forward_prehook(module, input):
            input = input[0]  # [B, C, H, W]
            # Repeat the image along a new time axis -> [B, C, T, H, W]
            input = input.unsqueeze(2).repeat(1, 1, frames_per_clip, 1, 1)
            return (input)

        encoder.register_forward_pre_hook(forward_prehook)

    encoder.to(device)
    encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key)
    return encoder
```
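One detail worth understanding in `init_model`: when `frames_per_clip > 1`, the forward pre-hook turns a single image into a static clip by repeating it along a new time axis. A standalone sketch of that tensor manipulation (mine, not from the post):

```python
import torch

x = torch.randn(1, 3, 224, 224)               # [B, C, H, W] image batch
clip = x.unsqueeze(2).repeat(1, 1, 16, 1, 1)  # [B, C, T, H, W] with T=16 identical frames
print(clip.shape)                             # torch.Size([1, 3, 16, 224, 224])
```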

```python
args_eval = params

args_pretrain = args_eval.get('pretrain')

checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder')
model_name = args_pretrain.get('model_name', None)
patch_size = args_pretrain.get('patch_size', None)
pretrain_folder = args_pretrain.get('folder', None)
ckp_fname = args_pretrain.get('checkpoint', None)
tag = args_pretrain.get('write_tag', None)
use_sdpa = args_pretrain.get('use_sdpa', True)
use_SiLU = args_pretrain.get('use_silu', False)
tight_SiLU = args_pretrain.get('tight_silu', True)
uniform_power = args_pretrain.get('uniform_power', False)
pretrained_path = os.path.join(pretrain_folder, ckp_fname)

# Optional [for video model]:
tubelet_size = args_pretrain.get('tubelet_size', 2)
frames_per_clip = args_pretrain.get('frames_per_clip', 1)

args_data = args_eval.get('data')
resolution = args_data.get('resolution', 224)
num_classes = args_data.get('num_classes')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: frames_per_clip and tubelet_size are hard-coded to 1 here rather than
# taken from the config values read above
encoder = init_model(
    crop_size=resolution,
    device=device,
    pretrained=pretrained_path,
    model_name=model_name,
    patch_size=patch_size,
    frames_per_clip=1,
    tubelet_size=1,
    uniform_power=uniform_power,
    checkpoint_key=checkpoint_key,
    use_SiLU=use_SiLU,
    tight_SiLU=tight_SiLU,
    use_sdpa=use_sdpa)

encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False

print(encoder)
```
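Since `load_pretrained` keeps the encoder's randomly initialized weights for any shape-mismatched key (it only prints a warning), a single mismatched patch embedding is enough to make downstream predictions look random. A quick diff of checkpoint shapes against the instantiated model, mirroring the loader's key cleanup (my sketch, not from the post):

```python
# List every parameter whose checkpoint shape disagrees with the model.
ckpt = torch.load(pretrained_path, map_location='cpu')
sd = ckpt.get('target_encoder', ckpt.get('encoder'))
sd = {k.replace('module.', '').replace('backbone.', ''): v for k, v in sd.items()}
for k, v in encoder.state_dict().items():
    if k in sd and sd[k].shape != v.shape:
        print(k, tuple(sd[k].shape), '!=', tuple(v.shape))
```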

Loading the classifier:

```python
from jepa.src.models.attentive_pooler import AttentiveClassifier

classifier = AttentiveClassifier(
    embed_dim=encoder.embed_dim,
    num_heads=encoder.num_heads,
    depth=1,
    num_classes=num_classes
).to(device)

checkpoint = torch.load("/kaggle/working/jepa/in1k-probe.pth.tar", map_location=torch.device('cpu'))
pretrained_dict = checkpoint['classifier']
pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}

print(classifier)

msg = classifier.load_state_dict(pretrained_dict)
print(msg)
```

Evaluating:

```python
from PIL import Image
from io import BytesIO
import pickle
import os
import pandas as pd

import torch
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
```
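Note that this transform only resizes and scales pixels to [0, 1]. If the encoder and probe were trained on inputs normalized with the standard ImageNet statistics (an assumption on my part, worth verifying against the repo's eval transforms), skipping normalization by itself can degrade predictions badly:

```python
# Variant with ImageNet mean/std normalization added (assumption: the probe
# expects normalized inputs -- check the repo's eval transforms to confirm).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```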

```python
parquet_file_path = "/kaggle/input/imagenet-1k-resized-256/data/train-00001-of-00052-886eb11e764e42fe.parquet"
df = pd.read_parquet(parquet_file_path)
print(df.shape)

file_path = "/kaggle/input/imagenet-1k-resized-256/classes.pkl"
with open(file_path, "rb") as f:
    classes = pickle.load(f)

for idx, row in df.sample(n=15).iterrows():
    img = Image.open(BytesIO(row['image']['bytes']))
    outs = classifier(encoder(transform(img).unsqueeze(0).to(device)))
    values, indices = torch.topk(outs, 10)
    display(img)  # notebook built-in
    print(f'real {row["label"]}')
    for n, i in enumerate(indices[0]):
        print(f'{i} : class {classes[int(i)]} value {values[0][n]}')
```
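One more defensive tweak: if any sampled image is grayscale or palette-encoded, `transform(img)` yields a tensor without three channels and the encoder call fails. Converting to RGB first avoids that (an assumption about this dataset on my part):

```python
# Force three channels before applying the transform.
img = Image.open(BytesIO(row['image']['bytes'])).convert('RGB')
```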

"""