tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch
https://nngeometry.readthedocs.io
MIT License

Error with float64 tensors #70

Closed: Xuzzo closed this issue 11 months ago

Xuzzo commented 11 months ago

Hello, and thanks for your work. EKFAC seems to have issues with models that use double precision (float64). Here is code to reproduce it:

  from nngeometry.metrics import FIM
  from nngeometry.object import PMatEKFAC
  import torch as th

  # Model parameters and inputs all use double precision
  dtype = th.float64

  class SimpleModel(th.nn.Module):

      def __init__(
          self,
          n_input: int,
          n_output: int,
      ):
          super().__init__()
          self.fc1 = th.nn.Linear(n_input, n_output, bias=True, dtype=dtype)

      def forward(self, x):
          return th.nn.Softmax(dim=-1)(self.fc1(x))

  if __name__ == "__main__":
      model = SimpleModel(10, 3)
      # 100 random float64 inputs with integer class labels in {0, 1, 2}
      dataset = th.utils.data.TensorDataset(th.randn(100, 10, dtype=dtype), th.randint(0, 3, (100,), dtype=th.long))
      loader = th.utils.data.DataLoader(dataset, batch_size=10)
      # EKFAC representation of the Fisher Information Matrix, 3 outputs
      F_ekfac = FIM(model, loader, PMatEKFAC, 3, variant='classif_logits')
      F_ekfac.update_diag(loader)

Running this raises "RuntimeError: expected scalar type Double but found Float".
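The thread does not say where inside nngeometry the mismatch happens, but this message is typical of PyTorch matrix products that mix float32 and float64 operands, which PyTorch does not promote automatically. The sketch below is a hypothetical illustration, not the library's code or the change made in the PR: `make_buffer` and `project` are invented names, and it assumes the cause is an internal buffer allocated with PyTorch's default dtype (float32) instead of the model's dtype.

```python
from typing import Optional

import torch


def make_buffer(n: int, like: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Hypothetical helper. Without `like`, the buffer gets PyTorch's
    # default dtype (float32 unless changed); with `like`, it copies
    # that tensor's dtype and device.
    if like is None:
        return torch.eye(n)
    return torch.eye(n, dtype=like.dtype, device=like.device)


def project(buf: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Matrix product: PyTorch raises instead of promoting float32 x float64
    return torch.mm(x, buf)


x64 = torch.randn(5, 4, dtype=torch.float64)

# Buffer created with the default dtype: mixing it with float64 data fails
try:
    project(make_buffer(4), x64)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)

# Buffer that follows the input's dtype: works and the result stays float64
out = project(make_buffer(4, like=x64), x64)
print(out.dtype)  # torch.float64
```

Deriving `dtype=` and `device=` from a tensor that is already at the caller's precision (or using helpers such as `torch.zeros_like`) is a common way to keep every intermediate at the same precision as the model.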

tfjgeorge commented 11 months ago

Hi, thanks for pointing this out.

This PR should fix it: https://github.com/tfjgeorge/nngeometry/pull/71

It still needs a bit more testing before it is merged into master, but you can use it in the meantime.

Xuzzo commented 11 months ago

Works now! Thanks a lot.