tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch
https://nngeometry.readthedocs.io
MIT License

RuntimeError: Shape is invalid for input of size #56

Closed. xand-stapleton closed this issue 1 year ago.

xand-stapleton commented 1 year ago

I'm trying to use NNGeometry's FIM helper (latest version from git) to compute the Fisher information matrix of my trivial model. As a stupidly basic example that recreates my problem, I've built a model with a single Linear layer and a single training sample that solves the matrix equation Ax = b, where A is a 3x3 matrix and x, b are 3x1 column vectors. The flattened A (9 values) is the input and x (3 values) is the output.

Here's my code (it's not meant for anything functional, it's just to replicate my problem):

import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Net, self).__init__()

        self.linear = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x):
        out = self.linear(x)
        return out

# Define the training data: A is 3x3, b is a 3x1 column vector
A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])

b = torch.tensor([[52.],
                  [124.],
                  [196.]])
# Define the model and the optimizer: the flattened A (9 values) goes in, x (3 values) comes out
model = Net(input_dim=9, output_dim=3)
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss(reduction='sum')

# Train the model by minimising the squared residual ||A x - b||^2
for epoch in range(2000):
    optimizer.zero_grad()
    y_pred = model(A.view(9))  # x as a vector of shape (3,)
    loss = loss_fn(A @ y_pred.view(3, 1), b)
    loss.backward()
    optimizer.step()
    if epoch % 500 == 0:
        print(f"epoch {epoch}: loss {loss.item():.4f}")

# Evaluate the model
with torch.no_grad():
    y_pred = model(A.view(9))
    print("Solution:\n", y_pred)
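A quick arithmetic check on this setup: A is singular (its third row minus its second equals its second row minus its first), so Ax = b has infinitely many solutions and training can only land on one of them. The sketch below uses plain tensor arithmetic to confirm that x = (6, 8, 10) is an exact solution and that (1, -2, 1) lies in the null space of A. Both vectors were worked out by hand for this check and are not part of the original post.

```python
import torch

# Same A and b as in the training script above
A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])
b = torch.tensor([[52.], [124.], [196.]])

# One exact solution of A x = b, found by hand
x_particular = torch.tensor([[6.], [8.], [10.]])
# A direction in the null space of A, i.e. A @ null_direction == 0
null_direction = torch.tensor([[1.], [-2.], [1.]])

residual = A @ x_particular - b
null_image = A @ null_direction
print("residual:", residual.flatten().tolist())
print("A @ null_direction:", null_image.flatten().tolist())

# Adding any multiple of the null direction gives another exact solution
x_other = x_particular + 3.0 * null_direction
other_residual = A @ x_other - b
print("residual for another solution:", other_residual.flatten().tolist())
```

Because the solution is not unique, the trained weights (and hence the printed solution) may legitimately differ between runs.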

Next, I create a simple DataLoader containing that single training sample (just as a proof of concept):

from torch.utils.data import DataLoader, Dataset

class TrivialDataset(Dataset):
    def __init__(self):
        self.data = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]).reshape(1,9)
    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

# Create the Dataloader
batch_size = 1
dataset = TrivialDataset()
loader = DataLoader(dataset, batch_size=batch_size)
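It helps to check what this DataLoader actually yields. PyTorch's default collate function stacks the items returned by `__getitem__` along a new leading batch dimension, so each batch here is a single bare tensor of shape (1, 9) rather than a tuple, and indexing it with `[0]` removes the batch dimension again. A minimal shape check using the same TrivialDataset:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class TrivialDataset(Dataset):
    def __init__(self):
        # One example: A flattened to 9 values, stored with shape (1, 9)
        self.data = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]).reshape(1, 9)

    def __getitem__(self, index):
        return self.data[index]  # shape (9,)

    def __len__(self):
        return len(self.data)

loader = DataLoader(TrivialDataset(), batch_size=1)
batch = next(iter(loader))

print(type(batch).__name__)   # a bare Tensor, not a tuple or list
print(tuple(batch.shape))     # (1, 9): batch dimension added by collation
print(tuple(batch[0].shape))  # (9,): indexing [0] drops the batch dimension
```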

Now, if I try to compute the FIM:

from nngeometry.metrics import FIM
from nngeometry.object import PMatDense

fisher_metric = FIM(model, loader, n_output=1, variant='regression', representation=PMatDense, device='cpu')
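For intuition about what this call should produce: for a regression model with a unit-variance Gaussian likelihood, the Fisher information matrix is the average over examples of J^T J, where J is the Jacobian of the network output with respect to the parameters. For this bias-free Linear(9, 3) layer, the Jacobian for an input x is kron(I_3, x^T), so the dense Fisher for one example is the 27x27 matrix kron(I_3, x x^T). The sketch below builds it by hand with autograd as a cross-check under those assumptions; NNGeometry's exact normalization and parameter ordering for `variant='regression'` with `PMatDense` may differ.

```python
import torch
import torch.nn as nn

# Same architecture as Net above: a single bias-free Linear(9 -> 3) layer
layer = nn.Linear(9, 3, bias=False)
x = torch.arange(1., 10.)  # the flattened A from the issue, shape (9,)

# Build the Jacobian of the 3 outputs w.r.t. the 27 weights, one output at a time
y = layer(x)
rows = []
for i in range(y.numel()):
    (grad_w,) = torch.autograd.grad(y[i], layer.weight, retain_graph=True)
    rows.append(grad_w.reshape(-1))  # flatten the (3, 9) weight row-major
J = torch.stack(rows)  # shape (3, 27)

# Fisher for one example under a unit-variance Gaussian likelihood: J^T J
fisher = J.T @ J  # shape (27, 27)

# For a linear layer this equals kron(I_3, x x^T), independent of the weights
expected = torch.kron(torch.eye(3), torch.outer(x, x))
print(tuple(fisher.shape))
print(torch.allclose(fisher, expected))
```

Since the result does not depend on the weight values, this also shows that, for this model, the Fisher is fully determined by the input data.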

There's a runtime error:

File ~/miniconda3/envs/torch/lib/python3.10/site-packages/nngeometry/generator/jacobian/__init__.py:77, in Jacobian.get_covariance_matrix(self, examples)
     75 inputs.requires_grad = True
     76 bs = inputs.size(0)
---> 77 output = self.function(*d).view(bs, self.n_output) \
     78     .sum(dim=0)
     79 for i in range(self.n_output):
     80     self.grads.zero_()

RuntimeError: shape '[9, 1]' is invalid for input of size 3

I think this comes about because FIM is trying to reshape the output based on the input size. Is this correct?
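The error can be reproduced without NNGeometry by mirroring the traceback lines above. This sketch assumes two things the traceback does not show: that the generator sets `inputs = d[0]` for each batch `d` just before line 75, and that the function built by the FIM helper runs the model on its first argument. `default_function` is a hypothetical stand-in for that function. With a bare (1, 9) tensor as the batch, `d[0]` has shape (9,), so `bs` becomes 9 while the model returns only 3 values.

```python
import torch
import torch.nn as nn

model = nn.Linear(9, 3, bias=False)  # same shape as Net(input_dim=9, output_dim=3)

def default_function(*d):
    # Hypothetical stand-in for the function the FIM helper builds:
    # run the model on the first element of the batch
    return model(d[0])

d = torch.arange(1., 10.).reshape(1, 9)  # what the DataLoader yields: a bare tensor
n_output = 1

inputs = d[0]                   # shape (9,): the batch dimension is gone
bs = inputs.size(0)             # 9, not 1
output = default_function(*d)   # *d unpacks along dim 0, so the model sees (9,) and returns (3,)

try:
    output.view(bs, n_output)
except RuntimeError as err:
    message = str(err)
    print(message)  # shape '[9, 1]' is invalid for input of size 3
```

So the reshape is indeed driven by the input's leading dimension, which here is the feature dimension rather than a batch dimension.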

Thanks

tfjgeorge commented 1 year ago

You need to pass n_output=3 when instantiating your object using the FIM helper. That way, the generator will expect every minibatch example to produce an output of size 3.
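For reference, this is the shape contract that `n_output` describes. With a properly batched input of shape (bs, 9), the model returns (bs, 3), so `view(bs, n_output)` with `n_output=3` succeeds and the sum over dim 0 leaves one value per output, which the generator then loops over (`for i in range(self.n_output)` in the traceback). A plain PyTorch sketch, where bs = 4 is an arbitrary illustrative batch size:

```python
import torch
import torch.nn as nn

model = nn.Linear(9, 3, bias=False)  # same shape as Net(input_dim=9, output_dim=3)
n_output = 3                          # size of each example's output

bs = 4                                # arbitrary batch size for illustration
inputs = torch.randn(bs, 9)           # a batch of flattened 3x3 matrices
output = model(inputs)                # shape (bs, 3)

# Mirrors the reshape and sum from traceback lines 77-78
summed = output.view(bs, n_output).sum(dim=0)
print(tuple(output.shape), tuple(summed.shape))  # (4, 3) (3,)
```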

xand-stapleton commented 1 year ago

Thanks for your quick reply! Unfortunately that doesn't change the error, except the 1 becomes a 3:

RuntimeError: shape '[9, 3]' is invalid for input of size 3

tfjgeorge commented 1 year ago

I think this comes from the fact that you are using a single example, instead of a minibatch of several examples. NNGeometry was designed to work with datasets with many examples.

xand-stapleton commented 1 year ago

Update: yep, that's the problem. Making the following change, so that each dataset item has shape (1, 9) instead of (9,), solves it:

class TrivialDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(9, dtype=torch.float32).view(1,1,9)
    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
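One way to read why this fix works: each item now has shape (1, 9), so the DataLoader yields a (1, 1, 9) tensor, `d[0]` has shape (1, 9), and `bs` is 1, which matches the model's (1, 3) output when `n_output=3`. Assuming the generator treats `d[0]` as the input batch (which the traceback's `inputs` and `bs` lines suggest but do not show), an alternative that keeps the stored data 2-D is to return a tuple from `__getitem__`; PyTorch's default collate then yields a list whose first element is the (batch, 9) input tensor. `TupleDataset` is a hypothetical name, and the sketch only checks shapes; it has not been tested against NNGeometry.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class FixedDataset(Dataset):
    # The fix from the thread: each item has shape (1, 9)
    def __init__(self):
        self.data = torch.arange(9, dtype=torch.float32).view(1, 1, 9)
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return len(self.data)

class TupleDataset(Dataset):
    # Hypothetical alternative: items keep shape (9,) but are wrapped in a tuple
    def __init__(self):
        self.data = torch.arange(9, dtype=torch.float32).view(1, 9)
    def __getitem__(self, index):
        return (self.data[index],)
    def __len__(self):
        return len(self.data)

fixed_batch = next(iter(DataLoader(FixedDataset(), batch_size=1)))
tuple_batch = next(iter(DataLoader(TupleDataset(), batch_size=1)))

print(tuple(fixed_batch.shape), tuple(fixed_batch[0].shape))    # (1, 1, 9) (1, 9)
print(type(tuple_batch).__name__, tuple(tuple_batch[0].shape))  # list (1, 9)
```

In both cases `d[0]` keeps its batch dimension, which is what the reshape in the traceback relies on.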

Thanks! :)