A simple reproducer:

import torch
from torch import nn
import sirfshampoo
device = torch.device("cuda:0")
class DebugNet(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # create a dummy parameter so optimizer instantiation does not crash
        self.factor = nn.Parameter(torch.randn(1, dtype=torch.float32))

    def forward(self, x):
        return self.factor * x

model = DebugNet().to(device)
optimizer = sirfshampoo.SIRFShampoo(model)
x = torch.randn(10).to(device)
y = model(x)
loss = (y - x).pow(2).sum()
loss.backward()
optimizer.step()
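# after one step the optimizer's internal state (presumably the Kronecker
# preconditioners) is populated, so the checkpoint below is non-trivial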
# save the model and optimizer state
store_dict = {
    "model_state": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
checkpoint_fname = "dummy_ckpt.pt"
torch.save(store_dict, checkpoint_fname)
# restore it
checkpoint = torch.load(checkpoint_fname, map_location="cpu", weights_only=False)
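# (weights_only=False is presumably needed because the optimizer state can
# contain non-tensor objects that a weights-only load would reject)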
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
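# NOTE (assumption): map_location="cpu" leaves the loaded optimizer state on
# the CPU while the model parameters still live on cuda:0; if sirfshampoo does
# not move its state back to the parameters' device, the next step() would hit
# a device mismatch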
# run the same thing
x = torch.randn(10).to(device)
y = model(x)
loss = (y - x).pow(2).sum()
loss.backward()
# this will now error
optimizer.step()
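
If the failure is indeed a device mismatch, one possible workaround (a sketch, assuming sirfshampoo keeps its per-parameter state in the standard optimizer.state mapping and that the offending entries are plain tensors) is to move the loaded state back to the parameters' device before stepping:

# hypothetical fix-up after optimizer.load_state_dict(...): move any tensor
# entries of the loaded state back to the target device
for state in optimizer.state.values():
    for key, value in state.items():
        if torch.is_tensor(value):
            state[key] = value.to(device)

Alternatively, loading the checkpoint with map_location=device instead of "cpu" may avoid the mismatch in the first place.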