feldberlin / wavenet

An unconditioned Wavenet implementation with fast generation.

Improve resuming of long runs #10

Closed purzelrakete closed 3 years ago

purzelrakete commented 3 years ago

What

When resuming an interrupted run, the following should be restored:

- the model weights
- the optimiser state

Here's an example that checkpoints and restores both the model weights and the optimiser state:

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.models as models

# Set up a small model, optimiser and loss for the demonstration
device = 'cuda'
model = models.resnet18()
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Synthetic inputs and labels for the demonstration
data = torch.randn(100, 3, 224, 224, device=device)
target = torch.randint(0, 1000, (100, ), device=device)

# Train briefly so the optimiser accumulates state (Adam's moment estimates)
nb_epochs = 10
for epoch in range(nb_epochs):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('Epoch {}, loss {}'.format(epoch, loss.item()))

# Create a reference loss from the trained model
model.eval()
with torch.no_grad():
    output = model(data)
    ref_loss = criterion(output, target)
print('reference loss {}'.format(ref_loss.item()))

# Capture a checkpoint, then restore it into freshly constructed objects
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict()
}

model = models.resnet18()
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

# Recompute the loss under the same conditions as the reference
model.eval()
with torch.no_grad():
    output = model(data)
    loss = criterion(output, target)
print('restored loss {}'.format(loss.item()))
print('abs error {}'.format((ref_loss - loss).abs().item()))
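
The example above keeps the checkpoint in memory. To survive an actual interruption, the same state dicts need to be written to disk with torch.save and read back with torch.load. Here is a minimal sketch of that, where the helper names save_checkpoint / load_checkpoint and the path checkpoint.pt are placeholders; it also carries the epoch counter so a resumed run can pick up where it left off:

import torch
import torch.optim as optim
import torchvision.models as models

device = 'cuda'

def save_checkpoint(model, optimizer, epoch, path='checkpoint.pt'):
    # Persist everything needed to resume: weights, optimiser state
    # and the epoch counter. The path is a placeholder.
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }, path)

def load_checkpoint(model, optimizer, path='checkpoint.pt'):
    # Rebuild state in place and return the epoch to resume from.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'] + 1

model = models.resnet18().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
save_checkpoint(model, optimizer, epoch=0)
start_epoch = load_checkpoint(model, optimizer)

Storing the epoch alongside the state dicts lets the resumed run continue its loop from start_epoch rather than starting from scratch.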

Why

As we train on larger datasets, training times grow accordingly. Runs may be interrupted for various reasons, and in those cases we don't want to lose all the work done up to that point.

Acceptance Criteria