f-dangel / sirfshampoo

[ICML 2024] SIRFShampoo: Structured inverse- and root-free Shampoo in PyTorch (https://arxiv.org/abs/2402.03496)
https://sirfshampoo.readthedocs.io
MIT License

[BUG] Pre-conditioner has wrong device when loading with different `map_location` #34

Open · bonevbs opened this issue 1 day ago

bonevbs commented 1 day ago

A simple reproducer:

import torch
from torch import nn
import sirfshampoo

device = torch.device("cuda:0")

class DebugNet(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

        # create a dummy parameter so that optimizer instantiation does not crash
        self.factor = nn.Parameter(torch.randn(1, dtype=torch.float32))

    def forward(self, x):
        return self.factor * x

model = DebugNet().to(device)

optimizer = sirfshampoo.SIRFShampoo(model)

x = torch.randn(10).to(device)
y = model(x)
loss = (y - x).pow(2).sum()

loss.backward()
optimizer.step()

# save the model and optimizer state in one checkpoint dict
store_dict = {
    "model_state": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}

checkpoint_fname = "dummy_ckpt.pt"

torch.save(store_dict, checkpoint_fname)

# restore it, mapping all tensors (including the optimizer state) to CPU
checkpoint = torch.load(checkpoint_fname, map_location="cpu", weights_only=False)

model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# run the same thing
x = torch.randn(10).to(device)
y = model(x)
loss = (y - x).pow(2).sum()

loss.backward()
# optimizer.step() now errors: the restored pre-conditioner is on CPU,
# while the parameters and gradients live on cuda:0
optimizer.step()
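
As a stop-gap until a fix is merged, loading the checkpoint onto the model's device seems to avoid the crash (a minimal sketch, assuming the pre-conditioner simply inherits whatever device `map_location` maps it to):

# load the checkpoint directly onto the target device instead of CPU,
# so the restored pre-conditioner tensors end up on cuda:0 as well
checkpoint = torch.load(checkpoint_fname, map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])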
f-dangel commented 1 day ago

Hi, thanks for reporting this!

I pushed a patch here (#35). Could you install from there and verify that it fixes your issue? I will then refactor further and merge into main.
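
In case it helps, one way to install directly from the patch branch (a sketch assuming a standard pip + git setup; GitHub exposes the pull request head as `refs/pull/35/head`):

pip install "git+https://github.com/f-dangel/sirfshampoo.git@refs/pull/35/head"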