Note: I'm using Python 3.10, so I'm not sure whether this is also an issue in 3.7.
I ran into a problem, shown in the image below:

[screenshot from the original issue]
After some digging, it seems `model_state` gets updated after training. This causes the next `sbs` instance (the one for `Adam`) to start from the previous model's parameters, which is exactly what `load_state_dict` was supposed to prevent. Consider the following code:
```python
import torch
import torch.nn as nn
import torch.optim as optim

# StepByStep, train_loader, and val_loader are defined in the book's code
torch.manual_seed(42)
model = nn.Sequential()
model.add_module('linear', nn.Linear(1, 1))
loss_fn = nn.MSELoss(reduction='mean')

model_state = model.state_dict()  # no copy is made here

optimizer = optim.SGD(model.parameters(), lr=0.1)
sbs = StepByStep(model, loss_fn, optimizer)
sbs.set_loaders(train_loader, val_loader)

print(model_state)
sbs.train(10)
print(model_state)
```
This prints:

```
OrderedDict([('linear.weight', tensor([[0.7645]])), ('linear.bias', tensor([0.8300]))])
OrderedDict([('linear.weight', tensor([[1.5287]])), ('linear.bias', tensor([1.2382]))])
```
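The reason is that, by default, `state_dict()` makes no copies: its values are (detached) tensors that share storage with the live parameters, so every training update shows up in `model_state`. A quick way to verify this, reusing the model from the snippet above:

```python
# The state_dict entry and the parameter point at the same storage
print(model_state['linear.weight'].data_ptr()
      == model.linear.weight.data_ptr())  # True
```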
One solution is to just `deepcopy` the model's state before any training and reuse that copy for `load_state_dict`. After this change, each instance starts from the original parameters, as sketched below.
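Here is a minimal sketch of that fix, reusing the names from the snippet above (the `Adam` learning rate is just illustrative):

```python
from copy import deepcopy

# Snapshot the untrained parameters; deepcopy produces independent
# tensors, so training no longer mutates the snapshot
model_state = deepcopy(model.state_dict())

optimizer = optim.SGD(model.parameters(), lr=0.1)
sbs = StepByStep(model, loss_fn, optimizer)
sbs.set_loaders(train_loader, val_loader)
sbs.train(10)

# Restore the original weights so the next instance (the one for Adam)
# starts from the same initial parameters
model.load_state_dict(model_state)
optimizer_adam = optim.Adam(model.parameters(), lr=0.1)
sbs_adam = StepByStep(model, loss_fn, optimizer_adam)
sbs_adam.set_loaders(train_loader, val_loader)
sbs_adam.train(10)
```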
Hi @nesaboz
Thank you for submitting this PR. You're absolutely right, it doesn't work properly in Python 3.10. I always test it on Google Colab (which currently uses Python 3.8) so I didn't notice this. Also, thank you for spotting and removing the duplicate function :-)
Best, Daniel
`model_state` will change after training; making a `deepcopy` before training is one way to restore the original one later.