Loss at epoch 1 : 4.19111598146205
Loss at epoch 2 : 3.7329235222874857
Loss at epoch 3 : 3.4954240723532073
Loss at epoch 4 : 3.334823436883031
Loss at epoch 5 : 3.2206918110652847
Test Score at epoch 5 : 3.27591894865036
Model Saved
Loss at epoch 6 : 3.1195735213707905
Loss at epoch 7 : 3.0338400857789174
Loss at epoch 8 : 2.969289777230243
Loss at epoch 9 : 2.8997164028031484
Loss at epoch 10 : 2.849334267937407
Test Score at epoch 10 : 3.000475114583969
Model Loaded
Test Score at epoch 5 loaded from checkpoint: 3.000475114583969
This is wrong: it seems that the load function didn't change the weights at all.
However, if I don't call evaluate before loading the model, the incorrect model loading doesn't occur, as in:
Loss at epoch 1 : 4.207225557492704
Loss at epoch 2 : 3.792082970239678
Loss at epoch 3 : 3.5438545029990527
Loss at epoch 4 : 3.3556719957565773
Loss at epoch 5 : 3.221576902331138
Test Score at epoch 5 : 3.234942561388016
Model Saved
Loss at epoch 6 : 3.1069990408663846
Loss at epoch 7 : 3.0254530724214046
Loss at epoch 8 : 2.957716809243572
Loss at epoch 9 : 2.9037014160837447
Loss at epoch 10 : 2.8454051285373922
Model Loaded
Test Score at epoch 5 loaded from checkpoint: 3.234942561388016
This is correct.
Here is the reference code to reproduce the error:
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """Two conv+pool stages followed by three fully connected layers.

    ``fc1`` expects ``16 * 5 * 5`` features per sample, which is what the
    two 5x5 convolutions and 2x2 poolings produce for 3-channel 32x32 input
    (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        num_classes: Number of output logits produced by the last linear
            layer. Defaults to 100, the value that was previously
            hard-coded.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the ``(batch, num_classes)`` logits for the batch ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Pin the batch dimension instead of letting it float (-1): an input
        # with the wrong spatial size then fails here rather than possibly
        # being reshaped into a different number of samples.
        x = x.view(x.size(0), 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from apex import amp
import time
# Convert images to tensors, then normalize every channel with mean 0.5 and
# std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-100 training split; only the training loader shuffles.
trainset = torchvision.datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
)

# CIFAR-100 test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=256,
    shuffle=False,
    num_workers=4,
)

# Model, optimizer and loss all live on the first CUDA device.
device = torch.device('cuda:0')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), eps=1e-4)
criterion = nn.CrossEntropyLoss()

# Hand model and optimizer to apex amp; 'O1' is the opt level the report
# says triggers the checkpoint-loading problem.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
print(model)
# Evaluate function
def evaluate():
    """Return the mean of the per-batch losses of ``model`` on ``testloader``.

    The model is put in eval mode and gradients are disabled while the test
    set is iterated. On exit the model is returned to whichever mode
    (train or eval) it was in on entry, even if a batch raises.

    Returns:
        The unweighted mean of ``criterion(outputs, targets)`` over all
        batches; every batch counts equally regardless of its size.
    """
    was_training = model.training
    model.eval()
    batch_losses = []
    try:
        with torch.no_grad():
            for data in testloader:
                inputs = data[0].to(device)
                # Move the labels to the device once and reuse them for the
                # loss instead of transferring them a second time.
                targets = data[-1].to(device)
                outputs = model(inputs)
                batch_losses.append(criterion(outputs, targets).item())
    finally:
        # Restore the caller's mode rather than unconditionally forcing
        # training mode.
        model.train(was_training)
    return np.array(batch_losses).mean()
# Save model function
def save_model(path='temp.pth'):
    """Write the model, optimizer and amp state dicts to one checkpoint file.

    Args:
        path: Destination file passed to ``torch.save``. Defaults to
            ``'temp.pth'``, the location that was previously hard-coded.
    """
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'amp': amp.state_dict(),
    }
    torch.save(checkpoint, path)
    print('Model Saved')
def load_model(path='temp.pth'):
    """Restore the model, optimizer and amp state from a checkpoint file.

    Args:
        path: Checkpoint file to read with ``torch.load``. Defaults to
            ``'temp.pth'``, the location that was previously hard-coded.
    """
    checkpoint = torch.load(path, map_location=device)
    # NOTE(review): the accompanying report observes that, with
    # opt_level='O1', calling evaluate() immediately before this function
    # leaves the following test score unchanged, as if the weights were not
    # reloaded. Presumably amp keeps state from the eval pass that is not
    # refreshed by load_state_dict -- confirm against apex.
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    amp.load_state_dict(checkpoint['amp'])
    print('Model Loaded')
# Train, checkpoint after epoch 5, then reload that checkpoint after epoch 10
# and re-evaluate. The epoch-10 branch breaks, so at most 10 epochs run.
for i in range(100):
    running_loss = []
    for data in trainloader:
        inputs = data[0].to(device)
        # Move the labels to the device once and reuse them for the loss
        # instead of transferring them a second time.
        targets = data[-1].to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # Backpropagate through amp's scaled loss; the unscaled loss is what
        # gets logged.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        running_loss.append(loss.item())
        optimizer.step()
        optimizer.zero_grad()
    print('Loss at epoch ' + str(i + 1) + ' : ' + str(np.array(running_loss).mean()))
    # break out after 10 epochs
    if (i + 1) == 5:
        print('Test Score at epoch ' + str(i + 1) + ' : ' + str(evaluate()))
        save_model()
    if (i + 1) == 10:
        # Repro toggle: per the accompanying report, uncommenting the next
        # line (evaluating right before load_model) makes the score printed
        # after loading match epoch 10 instead of the epoch-5 checkpoint.
        #print('Test Score at epoch ' + str(i + 1) + ' : ' + str(evaluate()))
        load_model()
        print('Test Score at epoch ' + str(5) + ' loaded from checkpoint: ' + str(evaluate()))
        break
There seems to be an issue with amp checkpoint loading when amp is set to opt_level='O1'.
It seems to occur if the code logic follows this:
The output is:
This is wrong: it seems that the load function didn't change the weights at all.
However, if I don't call evaluate before loading the model, the incorrect model loading doesn't occur, as in:
The output is:
This is correct.
Here is the reference code to reproduce the error: