tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch
https://nngeometry.readthedocs.io
MIT License

Error with FIM for resnet18 #40

Closed: TLESORT closed this issue 2 years ago.

TLESORT commented 2 years ago

Code:

import numpy as np
import torch
from torchvision import models
from torch.utils.data import DataLoader
from continuum.datasets import InMemoryDataset
from nngeometry.metrics import FIM
from nngeometry.object import PMatDiag

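# 20 random 264x264 RGB images (uint8, full 0-255 range) with one label each
random_x_data = np.random.randint(0, 256, size=(20, 264, 264, 3), dtype=np.uint8)
random_y_data = np.arange(20)
# wrap the arrays as a continuum TaskSet so a DataLoader can iterate over it
data = InMemoryDataset(random_x_data, random_y_data).to_taskset()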

model = models.resnet18(pretrained=True).cuda()

fisher_loader = DataLoader(data, batch_size=1, shuffle=True, num_workers=6)

fim = FIM(model=model.eval(),
         loader=fisher_loader,
         representation=PMatDiag,
         n_output=10,  # number of model outputs; ImageNet-pretrained resnet18 returns 1000 logits
         variant='classif_logits',
         device='cuda')
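
For reference, as we understand it, variant='classif_logits' takes the Fisher with respect to the model's own predictive distribution, F = E_x[ sum_c p(c|x) grad log p(c|x) grad log p(c|x)^T ], and PMatDiag keeps only its diagonal. The sketch below computes that diagonal in plain PyTorch, one example and one class at a time, as a slow cross-check on a small model. It is not nngeometry's implementation: the helper name diag_fim_classif_logits and the averaging over examples are our own assumptions. It needs one backward pass per class per example, so it is meant for toy models, not for a 1000-class resnet18.

```python
import torch
import torch.nn as nn


def diag_fim_classif_logits(model, inputs):
    """Diagonal of the Fisher information of a softmax classifier.

    Hypothetical helper (not part of nngeometry). For every example x and
    every class c it accumulates p(c|x) * (d log p(c|x) / d theta)**2, where
    p(c|x) is detached so it acts as a fixed weight, then averages over x.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    diag = [torch.zeros_like(p) for p in params]
    for x in inputs:  # one example at a time, like batch_size=1 above
        log_probs = torch.log_softmax(model(x.unsqueeze(0)), dim=1)[0]
        probs = log_probs.exp().detach()
        # Loop over every output logit. nngeometry's n_output presumably
        # plays this role, so it should equal the model's real output size.
        for c in range(log_probs.shape[0]):
            grads = torch.autograd.grad(log_probs[c], params, retain_graph=True)
            for d, g in zip(diag, grads):
                d += probs[c] * g.pow(2)  # in-place accumulation per parameter
    return [d / len(inputs) for d in diag]


if __name__ == "__main__":
    torch.manual_seed(0)
    # Tiny stand-in for resnet18: conv + BatchNorm + ReLU + linear head, 10 classes.
    net = nn.Sequential(
        nn.Conv2d(3, 4, kernel_size=3),
        nn.BatchNorm2d(4),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(4, 10),
    ).eval()
    images = torch.randn(2, 3, 8, 8)
    names = [n for n, p in net.named_parameters() if p.requires_grad]
    for name, d in zip(names, diag_fim_classif_logits(net, images)):
        print(name, tuple(d.shape), float(d.sum()))
```

For a bias-free linear softmax model (logits = W x) the diagonal entry for W[k, j] works out to the mean over examples of p_k (1 - p_k) x_j^2, which makes a handy sanity check for the helper.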

Error: (posted as a screenshot; the traceback text was not preserved)

tfjgeorge commented 2 years ago

This is solved by #42.