I am trying to load the pretrained ImageNet-1k model in Kaggle to interact with it, but the performance I am getting is random at best.
Any help is much appreciated.
Dataset required in Kaggle: imagenet-1k-resized-256.
I copied the relevant pieces from the eval script.
The code is as follows; it takes about a minute to run, mostly for the download.
def load_pretrained(encoder, pretrained, checkpoint_key='target_encoder'):
    """Load pretrained weights from the checkpoint file at `pretrained` into `encoder`.

    Falls back to the 'encoder' key when `checkpoint_key` is absent from the
    checkpoint. Keys missing from the checkpoint, or whose tensor shapes
    disagree with the model, are reported and left at the model's current
    values. Returns the (mutated) encoder.
    """
    print(f'Loading pretrained model from {pretrained}')
    checkpoint = torch.load(pretrained, map_location='cpu')
    try:
        pretrained_dict = checkpoint[checkpoint_key]
    except Exception:
        pretrained_dict = checkpoint['encoder']

    # Strip DataParallel's 'module.' and any 'backbone.' prefix so the
    # checkpoint keys line up with the bare model's state-dict keys.
    pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}
    pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()}
    for k, v in encoder.state_dict().items():
        if k not in pretrained_dict:
            print(f'key "{k}" could not be found in loaded state dict')
        elif pretrained_dict[k].shape != v.shape:
            print(f'key "{k}" is of different shape in model and loaded state dict')
            # Keep the model's own tensor so load_state_dict does not fail.
            pretrained_dict[k] = v
    msg = encoder.load_state_dict(pretrained_dict, strict=False)
    print(f'loaded pretrained model with msg: {msg}')
    print(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}')
    del checkpoint
    return encoder
# Load the linear-probe classifier weights; strip DataParallel's
# 'module.' prefix so the keys match the locally-built classifier.
checkpoint = torch.load("/kaggle/working/jepa/in1k-probe.pth.tar", map_location=torch.device('cpu'))
pretrained_dict = checkpoint['classifier']
pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}

# Class-index -> label mapping shipped with the Kaggle dataset.
file_path = "/kaggle/input/imagenet-1k-resized-256/classes.pkl"
with open(file_path, "rb") as f:
    classes = pickle.load(f)
# Spot-check the frozen encoder + probe on a few random training rows.
for idx, row in df.sample(n=15).iterrows():
    # Decode as RGB explicitly: some ImageNet JPEGs are grayscale/CMYK and
    # would otherwise yield a tensor with the wrong channel count.
    img = Image.open(BytesIO(row['image']['bytes'])).convert('RGB')
    # NOTE(review): if `transform` does not include ImageNet normalization,
    # the pretrained probe will see out-of-distribution inputs and the
    # top-10 below will look random — verify the transform first.
    outs = classifier(encoder(transform(img).unsqueeze(0).to(device)))
    values, indices = torch.topk(outs, 10)
    display(img)
    print(f'real {row["label"]}')
    for n, i in enumerate(indices[0]):
        print(f'{i} : class {classes[int(i)]} value {values[0][n]} ')
I am trying to load the pretrained ImageNet-1k model in Kaggle to interact with it, but the performance I am getting is random at best. Any help is much appreciated. Dataset required in Kaggle: imagenet-1k-resized-256.
I copied the relevant pieces from the eval script. The code is as follows; it takes about a minute to run, mostly for the download.
"""
Get the repo (cell 1):
!git clone https://github.com/facebookresearch/jepa.git import os os.chdir('/kaggle/working/jepa') !pip install .
!wget https://dl.fbaipublicfiles.com/jepa/vitl16/in1k-probe.pth.tar !wget https://dl.fbaipublicfiles.com/jepa/vitl16/vitl16.pth.tar
Config for the model I want:
import yaml

# Load the published eval config, then point it at the downloaded checkpoint.
with open('/kaggle/working/jepa/configs/evals/vitl16_in1k.yaml', 'r') as y_file:
    params = yaml.load(y_file, Loader=yaml.FullLoader)
params['pretrain']['folder'] = '/kaggle/working/jepa'
params['pretrain']['checkpoint'] = 'vitl16.pth.tar'
Loading the model:
import jepa.src.models.vision_transformer as vit
import torch


def load_pretrained(encoder, pretrained, checkpoint_key='target_encoder'):
    """Load pretrained weights from the checkpoint at `pretrained` into `encoder`.

    Falls back to the 'encoder' key when `checkpoint_key` is absent; missing
    or shape-mismatched keys are reported and left at the model's values.
    """
    print(f'Loading pretrained model from {pretrained}')
    checkpoint = torch.load(pretrained, map_location='cpu')
    try:
        pretrained_dict = checkpoint[checkpoint_key]
    except Exception:
        pretrained_dict = checkpoint['encoder']
    # Strip DataParallel / backbone prefixes so keys match the bare model.
    pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}
    pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()}
    for k, v in encoder.state_dict().items():
        if k not in pretrained_dict:
            print(f'key "{k}" could not be found in loaded state dict')
        elif pretrained_dict[k].shape != v.shape:
            print(f'key "{k}" is of different shape in model and loaded state dict')
            pretrained_dict[k] = v
    msg = encoder.load_state_dict(pretrained_dict, strict=False)
    print(f'loaded pretrained model with msg: {msg}')
    print(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}')
    del checkpoint
    return encoder
def init_model(
    device,
    pretrained,
    model_name,
    patch_size=16,
    crop_size=224,
    # Video-specific parameters
    frames_per_clip=1,
    tubelet_size=2,
    use_sdpa=False,
    use_SiLU=False,
    tight_SiLU=True,
    uniform_power=False,
    checkpoint_key='target_encoder',
):
    """Build a ViT encoder, move it to `device`, and load pretrained weights.

    NOTE(review): the forum paste garbled this cell (`vit.dictmodel_name`
    lost its brackets); reconstructed as the model-registry lookup
    `vit.__dict__[model_name](...)` from the upstream eval script —
    confirm the exact constructor kwargs against the repo.
    """
    encoder = vit.__dict__[model_name](
        img_size=crop_size,
        patch_size=patch_size,
        num_frames=frames_per_clip,
        tubelet_size=tubelet_size,
        uniform_power=uniform_power,
        use_sdpa=use_sdpa,
        use_SiLU=use_SiLU,
        tight_SiLU=tight_SiLU,
    )
    if frames_per_clip > 1:
        # Repeat a single image along the time axis so an image batch
        # [B, C, H, W] can feed a video model expecting [B, C, T, H, W].
        def forward_prehook(module, input):
            input = input[0]  # [B, C, H, W]
            input = input.unsqueeze(2).repeat(1, 1, frames_per_clip, 1, 1)
            return (input)
        encoder.register_forward_pre_hook(forward_prehook)
    encoder.to(device)
    encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key)
    return encoder
# Unpack the eval config and instantiate the frozen encoder.
args_eval = params
args_pretrain = args_eval.get('pretrain')

checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder')
model_name = args_pretrain.get('model_name', None)
patch_size = args_pretrain.get('patch_size', None)
pretrain_folder = args_pretrain.get('folder', None)
ckp_fname = args_pretrain.get('checkpoint', None)
tag = args_pretrain.get('write_tag', None)
use_sdpa = args_pretrain.get('use_sdpa', True)
use_SiLU = args_pretrain.get('use_silu', False)
tight_SiLU = args_pretrain.get('tight_silu', True)
uniform_power = args_pretrain.get('uniform_power', False)
pretrained_path = os.path.join(pretrain_folder, ckp_fname)
# Optional [for video models]:
tubelet_size = args_pretrain.get('tubelet_size', 2)
frames_per_clip = args_pretrain.get('frames_per_clip', 1)

args_data = args_eval.get('data')
resolution = args_data.get('resolution', 224)
num_classes = args_data.get('num_classes')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single image frames (frames_per_clip=1, tubelet_size=1) for the image probe.
encoder = init_model(
    crop_size=resolution,
    device=device,
    pretrained=pretrained_path,
    model_name=model_name,
    patch_size=patch_size,
    frames_per_clip=1,
    tubelet_size=1,
    uniform_power=uniform_power,
    checkpoint_key=checkpoint_key,
    use_SiLU=use_SiLU,
    tight_SiLU=tight_SiLU,
    use_sdpa=use_sdpa,
)
# Freeze the encoder — inference only.
encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False
print(encoder)
Loading the classifier:
from jepa.src.models.attentive_pooler import AttentiveClassifier

# Attentive probe head sized to the encoder's embedding dimension.
classifier = AttentiveClassifier(
    embed_dim=encoder.embed_dim,
    num_heads=encoder.num_heads,
    depth=1,
    num_classes=num_classes,
).to(device)

# Load the published linear-probe weights, stripping DataParallel's prefix.
checkpoint = torch.load("/kaggle/working/jepa/in1k-probe.pth.tar", map_location=torch.device('cpu'))
pretrained_dict = checkpoint['classifier']
pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()}

print(classifier)
msg = classifier.load_state_dict(pretrained_dict)
print(msg)
Evaluating:
from PIL import Image
from io import BytesIO
import pickle
import os
import pandas as pd
import torch
from torchvision import transforms

# NOTE(review): this is the likely cause of the "random" predictions —
# the original transform stopped at ToTensor(), but the pretrained probe
# was evaluated on ImageNet-normalized inputs. Add the standard ImageNet
# mean/std normalization (confirm against the repo's eval transforms).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
# One shard of the resized ImageNet training set.
parquet_file_path = "/kaggle/input/imagenet-1k-resized-256/data/train-00001-of-00052-886eb11e764e42fe.parquet"
df = pd.read_parquet(parquet_file_path)
print(df.shape)

# Class-index -> label mapping shipped with the dataset.
file_path = "/kaggle/input/imagenet-1k-resized-256/classes.pkl"
with open(file_path, "rb") as f:
    classes = pickle.load(f)

# Spot-check 15 random rows: show the image, true label, and top-10 predictions.
for idx, row in df.sample(n=15).iterrows():
    # Decode as RGB explicitly: grayscale/CMYK JPEGs would otherwise give
    # the wrong channel count after ToTensor().
    img = Image.open(BytesIO(row['image']['bytes'])).convert('RGB')
    outs = classifier(encoder(transform(img).unsqueeze(0).to(device)))
    values, indices = torch.topk(outs, 10)
    display(img)
    print(f'real {row["label"]}')
    for n, i in enumerate(indices[0]):
        print(f'{i} : class {classes[int(i)]} value {values[0][n]} ')
"""