SamsungSAILMontreal / ghn3

Code for "Can We Scale Transformers to Predict Parameters of Diverse ImageNet Models?" [ICML 2023]
https://arxiv.org/abs/2303.04143
MIT License

question regarding evaluation on swin-t #4

Closed sorobedio closed 4 months ago

sorobedio commented 4 months ago

This is the code I use to evaluate ghn3clm16 on CIFAR-10, but the top-1 result was 0.41 and the top-5 result was 1.68.

Did I do something wrong? Here is my test code:

```python
import torch
import torchvision
from ppuda.config import init_config
from ghn3 import from_pretrained, norm_check, Graph, Logger
from ppuda.utils import infer, AvgrageMeter, adjust_net
from ppuda.vision.loader import image_loader

args = init_config(mode='eval', debug=0, arch='resnet50', split='torch')  # load arguments from the command line
assert args.arch is not None, ('architecture must be specified using, e.g. --arch resnet50', args.arch)


def bn_set_train(module):
    if isinstance(module, torch.nn.BatchNorm2d):
        module.track_running_stats = False
        module.training = True


ghn = from_pretrained(args.ckpt, debug_level=args.debug).to(args.device)  # get a pretrained GHN

is_imagenet = args.dataset.startswith('imagenet')
is_torch = args.split == 'torch'
print('loading the %s dataset...' % args.dataset)
val_loader, num_classes = image_loader(args.dataset,
                                       args.data_dir,
                                       test=True,
                                       test_batch_size=args.test_batch_size,
                                       num_workers=args.num_workers,
                                       noise=args.noise,
                                       im_size=224,  # args.imsize
                                       seed=args.seed)[1:]

model = eval(f'torchvision.models.{args.arch}()').to(args.device)  # create a PyTorch model

if is_torch and not is_imagenet:
    model = adjust_net(model, large_input=False)  # adjust the model for small images such as 32x32 in CIFAR-10

with torch.no_grad():  # to improve efficiency
    model = ghn(model)  # predict parameters of the model
model.eval()  # set to the eval mode to disable dropout, etc.

model.apply(bn_set_train)
top1, top5 = infer(model.to(args.device), val_loader, verbose=False)
print(f'top1: {top1} and top5: {top5}')
```

bknyaz commented 4 months ago

You need to fine-tune the predicted parameters on CIFAR-10, because our GHN-3 predicts ImageNet parameters.
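
For readers who land here with the same question, below is a minimal sketch of what "fine-tune the predicted parameters" could look like in plain PyTorch. It is not the authors' training script. The helper names `replace_head` and `fine_tune` are hypothetical, the hyperparameters are placeholders, and the head swap assumes a torchvision-style ResNet whose classifier is stored in `model.fc`. The head swap is included because `torchvision.models.resnet50()` builds a 1000-way ImageNet classifier by default, which likely explains why scores against CIFAR-10's 10 labels come out so low.

```python
import torch
import torch.nn as nn


def replace_head(model: nn.Module, num_classes: int) -> nn.Module:
    """Swap the final `fc` layer (torchvision ResNet naming) for a new one.

    Hypothetical helper: the new layer is randomly initialized, while all
    other (e.g. GHN-predicted) parameters are left untouched.
    """
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def fine_tune(model, loader, epochs=1, lr=0.01, device="cpu"):
    """Plain SGD fine-tuning loop.

    The learning rate, momentum and epoch count are placeholders,
    not the settings used in the paper.
    """
    model.to(device).train()
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            opt.step()
    return model
```

With the script from the question, this would be called after `model = ghn(model)`, for example `model = fine_tune(replace_head(model, 10), train_loader)`, where `train_loader` is a CIFAR-10 training loader you supply. This assumes the predicted weights end up registered as trainable parameters, which is worth checking before training.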