tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch
https://nngeometry.readthedocs.io
MIT License

Shape problems (x2) #41

Closed: TLESORT closed this issue 2 years ago

TLESORT commented 2 years ago

Code:

import numpy as np
from torch.utils.data import DataLoader
from nngeometry.layers import WeightNorm1d
from nngeometry.metrics import FIM
from nngeometry.object import PMatDiag
from continuum.datasets import InMemoryDataset

classifier = WeightNorm1d(in_features=512, out_features=20)  # by the way, this also fails with nn.Linear
# 20 random integer-valued samples with 512 features each, one label per sample
random_x_data = np.random.randint(0, 255, size=(20, 512))
random_y_data = np.arange(20)
data = InMemoryDataset(random_x_data, random_y_data).to_taskset()
fisher_loader = DataLoader(data, batch_size=128, shuffle=True, num_workers=6)
# Diagonal Fisher Information Matrix of the classifier's logits
fim = FIM(model=classifier,
          loader=fisher_loader,
          representation=PMatDiag,
          n_output=20,
          variant='classif_logits',
          device='cpu')

Error: (screenshot of the traceback, not preserved)

If I work around the problem with view by modifying the WeightNorm1d class, I get another error: (screenshot not preserved)

(I only modified the forward method of WeightNorm1d, as follows:)

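def forward(self, input: Tensor) -> Tensor:
    # Flatten the input to (N, in_features) before the linear map
    input = input.view(-1, self.in_features)
    # Squared L2 norm of each weight row, plus eps for numerical stability
    norm2 = (self.weight**2).sum(dim=1, keepdim=True) + self.eps
    # Bias-free linear map with row-normalized weights
    return F.linear(input,
                    self.weight / torch.sqrt(norm2))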
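The modified forward computes a bias-free linear map in which each weight row w_i is divided by sqrt(||w_i||^2 + eps), so every row has (almost exactly) unit norm. A minimal standalone sketch of that computation, assuming only PyTorch (weight_norm_linear is a hypothetical name, not part of nngeometry), also shows what the added flattening does to an input that carries an extra dimension:

```python
import torch
import torch.nn.functional as F


def weight_norm_linear(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical free-function version of the modified WeightNorm1d.forward.

    `weight` has shape (out_features, in_features); each row is divided by
    sqrt(||row||^2 + eps) before the bias-free linear map.
    """
    in_features = weight.shape[1]
    # Flatten any extra dims, e.g. (B, 1, in_features) -> (B, in_features)
    x = x.reshape(-1, in_features)
    # Squared L2 norm of each row, shape (out_features, 1)
    norm2 = (weight ** 2).sum(dim=1, keepdim=True) + eps
    return F.linear(x, weight / torch.sqrt(norm2))


torch.manual_seed(0)
w = torch.randn(20, 512)
x = torch.randn(8, 1, 512)  # a batch of 8 with an extra singleton dimension
y = weight_norm_linear(x, w)
print(y.shape)  # torch.Size([8, 20])
```

Flattening can hide a shape mismatch rather than fix it: an input of shape (4, 3, 512) becomes 12 rows, so the output no longer has one row per sample in the batch. The sketch uses reshape instead of view only because reshape also accepts non-contiguous inputs; for contiguous inputs the two give the same result.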
TLESORT commented 2 years ago

My bad, the problem probably comes from the InMemoryDataset class.
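One way to check the InMemoryDataset hypothesis: WeightNorm1d and nn.Linear both expect a floating-point input of shape (batch, 512), whereas np.random.randint returns an integer array. A quick control experiment, sketched here with plain PyTorch (describe_first_batch is a hypothetical helper, not an nngeometry or continuum API), feeds the same kind of random data through a TensorDataset converted to float32 and reports the shape and dtype of its first batch:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def describe_first_batch(loader):
    """Hypothetical helper: (shape, dtype) of the inputs in the first batch.

    Assumes each batch is a sequence whose first element is the input tensor,
    e.g. (x, y) from a TensorDataset.
    """
    batch = next(iter(loader))
    x = batch[0]
    return tuple(x.shape), x.dtype


# Control: random data as in the report, but as a plain float32 TensorDataset
# instead of a continuum InMemoryDataset.
random_x = torch.randint(0, 255, (20, 512)).float()
random_y = torch.arange(20)
control_loader = DataLoader(TensorDataset(random_x, random_y), batch_size=128)

shape, dtype = describe_first_batch(control_loader)
print(shape, dtype)  # (20, 512) torch.float32
```

Running describe_first_batch(fisher_loader) on the original loader (assuming the continuum task set also puts the input first in each batch) would show whether it yields something other than a (batch, 512) float tensor. If the control batches look right, control_loader can be passed as loader= to the same FIM(...) call to confirm that the error comes from the dataset rather than from nngeometry.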