Hey @nikvaessen, this is very odd. I'll look into it.
Btw, in case you didn't know, the recommended way to load a checkpoint in Fabric is

```python
fabric.load("first_session.ckpt", {"network": network, "opt": opt})
```

because it generalizes across all strategies and accelerators, and it offers a convenient way to make scripts stateful in general. It will also pass your assertion, so you can use it as a workaround until I make the bugfix.
Thanks for reporting!
I've tried the following change to the reloading part of the reproduction code:

```python
import lightning
import torch

# `Network` is the model class from the original reproduction (not shown here)


def second_session():
    fabric = lightning.Fabric(accelerator="gpu", devices=1)
    network = Network()
    opt = torch.optim.Adam(network.parameters())
    network, opt = fabric.setup(network, opt)

    # Load through Fabric instead of calling opt.load_state_dict() directly
    state = {"network": network, "opt": opt}
    remainder = fabric.load("first_session.ckpt", state)

    print("remainder:", remainder)
    print("wrapper", opt.state_dict())
    print("optimizer", opt.optimizer.state_dict())
    print("checkpoint\n", torch.load("first_session.ckpt"))
    # A non-empty "state" means the optimizer state was actually restored
    assert len(opt.state_dict()["state"]) > 0
```
This still results in:

```
remainder: {}
wrapper {'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}
optimizer {'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}
```
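The `'state': {}` in this output is what signals the problem. A plain PyTorch sketch (no Fabric involved; the small `nn.Linear` model is an assumption standing in for `Network`) shows that Adam only creates per-parameter state after its first `step()`, and that a working `load_state_dict()` carries that state over to a freshly constructed optimizer, one entry per parameter. One caveat follows from this: if the first session never stepped the optimizer, the checkpoint itself would already contain an empty state.

```python
# Plain PyTorch sketch (no Fabric), assuming a tiny nn.Linear model:
# Adam's per-parameter state only exists after the first step, and
# load_state_dict() copies it into a fresh optimizer.
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.Adam(model.parameters())
print(opt.state_dict()["state"])  # {} before any step: nothing to save yet

model(torch.randn(8, 4)).sum().backward()
opt.step()
saved = opt.state_dict()
print(sorted(saved["state"][0].keys()))  # ['exp_avg', 'exp_avg_sq', 'step']

# Restore into a freshly constructed optimizer over the same parameters
fresh = torch.optim.Adam(model.parameters())
fresh.load_state_dict(saved)
print(len(fresh.state_dict()["state"]))  # 2: one entry each for weight and bias
```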
If I run my code with `pip install git+`, the state dictionary is correctly loaded. So thanks for the bugfix :)
@nikvaessen Thanks for confirming!
Bug description

When `load_state_dict()` is called on an optimizer object returned by `fabric.setup(...)`, the resulting `state_dict` will be empty.

What version are you seeing the problem on?
How to reproduce the bug
Error messages and logs
The checkpoint:

The result of calling `opt.state_dict()` after calling `opt.load_state_dict()`:
Current environment
