What

When resuming an interrupted run, the following should be restored (a combined sketch covering all four follows the optimiser example below):

- the point in the learning schedule that we left off at
- the state of the half precision gradient scaler
- the state of the optimiser
- the state of the dataset sampler
Here's an example for the optimiser state:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

device = 'cuda'
model = models.resnet18()
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# dummy data: 100 ImageNet-sized samples with random labels
data = torch.randn(100, 3, 224, 224, device=device)
target = torch.randint(0, 1000, (100,), device=device)

# train for a few epochs so the optimizer accumulates internal state
nb_epochs = 10
for epoch in range(nb_epochs):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('Epoch {}, loss {}'.format(epoch, loss.item()))

# create a reference loss to compare against after restoring
model.eval()
with torch.no_grad():
    output = model(data)
    ref_loss = criterion(output, target)
print('reference loss {}'.format(ref_loss.item()))

# checkpoint the model and optimizer state
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict()
}

# rebuild the objects from scratch, then restore the checkpointed state
model = models.resnet18()
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

# the restored model should reproduce the reference loss
model.eval()
with torch.no_grad():
    output = model(data)
    loss = criterion(output, target)
print('restored loss {}'.format(loss.item()))
print('abs error {}'.format((ref_loss - loss).abs().item()))
```
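
The example above only covers the model and the optimiser. Here's a minimal sketch of a full checkpoint that also captures the remaining items from the list: the point in the learning schedule, the half precision gradient scaler, and the shuffling state. It assumes a StepLR schedule and torch.cuda.amp.GradScaler; the epoch value and the checkpoint.pt path are illustrative stand-ins rather than part of the run above.

```python
import torch
import torch.optim as optim
import torchvision.models as models

device = 'cuda'
model = models.resnet18().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30)  # assumed schedule
scaler = torch.cuda.amp.GradScaler()  # half precision gradient scaler

epoch = 4  # stand-in for wherever the run was interrupted

# save: one dict holding every piece of state from the list
checkpoint = {
    'epoch': epoch,                       # point in the learning schedule
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),  # LR schedule state
    'scaler': scaler.state_dict(),        # grad scaler state
    'rng_state': torch.get_rng_state(),   # CPU RNG that drives shuffling
}
torch.save(checkpoint, 'checkpoint.pt')

# restore: rebuild the objects, then load their state back in
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
scaler.load_state_dict(checkpoint['scaler'])
torch.set_rng_state(checkpoint['rng_state'])
start_epoch = checkpoint['epoch'] + 1
# for a DistributedSampler, also call sampler.set_epoch(start_epoch)
# at the top of each resumed epoch so shuffling lines up
```

Note that restoring the CPU RNG state only covers shuffling driven by the default generator; DataLoader workers and CUDA ops keep their own RNG state, so bitwise-exact resumption needs more care.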
Why

As we train on larger datasets, training times grow much longer. Runs may be interrupted for various reasons, and in those cases we don't want to lose all the work done up to that point.
Acceptance Criteria
- [x] Identified all of the state we need to save
- [x] All of the state is correctly checkpointed
- [x] Checkpoints are correctly restored
- [ ] An example training run could be interrupted and restored, and everything looks good
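
One practical note on the last criterion: if the run is interrupted while torch.save is midway through writing, the checkpoint file itself can be left truncated. A common guard, sketched below under the assumption that checkpoints are written to local disk (save_checkpoint_atomic is a hypothetical helper, not existing code), is to write to a temporary file first and atomically rename it into place:

```python
import os
import tempfile
import torch

def save_checkpoint_atomic(checkpoint, path):
    # write to a temporary file in the same directory, then rename;
    # os.replace is atomic on POSIX, so a reader only ever sees a
    # complete checkpoint, even if we are killed mid-save
    tmp_dir = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir, suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            torch.save(checkpoint, f)
        os.replace(tmp_path, path)
    except BaseException:
        os.remove(tmp_path)
        raise
```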