NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

opt_level='O1' wrong loading checkpoint behavior when loaded after evaluate function #796

Open Synicix opened 4 years ago

Synicix commented 4 years ago

There seems to be an issue with amp checkpoint loading when amp is initialized with opt_level='O1'.

It seems to occur when the run follows this sequence (a condensed sketch; train_one_epoch stands in for the full training loop in the script further down):
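for epoch in range(5):
    train_one_epoch()        # hypothetical helper; the real loop is in the full script below
print(evaluate())            # evaluating here, before the later load, triggers the bug
save_model()                 # checkpoint model, optimizer, and amp state at epoch 5
for epoch in range(5):
    train_one_epoch()
print(evaluate())
load_model()                 # restore the epoch-5 checkpoint
print(evaluate())            # expected the epoch-5 score, but gets the epoch-10 score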

The output is:

Loss at epoch 1 : 4.19111598146205
Loss at epoch 2 : 3.7329235222874857
Loss at epoch 3 : 3.4954240723532073
Loss at epoch 4 : 3.334823436883031
Loss at epoch 5 : 3.2206918110652847
Test Score at epoch 5 : 3.27591894865036
Model Saved
Loss at epoch 6 : 3.1195735213707905
Loss at epoch 7 : 3.0338400857789174
Loss at epoch 8 : 2.969289777230243
Loss at epoch 9 : 2.8997164028031484
Loss at epoch 10 : 2.849334267937407
Test Score at epoch 10 : 3.000475114583969
Model Loaded
Test Score at epoch 5 loaded from checkpoint: 3.000475114583969

This is wrong: the score after loading (3.000475114583969) matches the epoch-10 test score rather than the saved epoch-5 score (3.27591894865036), as if load_model() didn't change the weights at all.

However, if I don't call evaluate before loading the model, the incorrect loading does not occur. The sequence is identical except that the evaluate() call right before load_model() is skipped:
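for epoch in range(5):
    train_one_epoch()
print(evaluate())
save_model()
for epoch in range(5):
    train_one_epoch()
load_model()                 # no evaluate() call before the load this time
print(evaluate())            # correctly returns the saved epoch-5 score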

The output is:

Loss at epoch 1 : 4.207225557492704
Loss at epoch 2 : 3.792082970239678
Loss at epoch 3 : 3.5438545029990527
Loss at epoch 4 : 3.3556719957565773
Loss at epoch 5 : 3.221576902331138
Test Score at epoch 5 : 3.234942561388016
Model Saved
Loss at epoch 6 : 3.1069990408663846
Loss at epoch 7 : 3.0254530724214046
Loss at epoch 8 : 2.957716809243572
Loss at epoch 9 : 2.9037014160837447
Loss at epoch 10 : 2.8454051285373922
Model Loaded
Test Score at epoch 5 loaded from checkpoint: 3.234942561388016

This is the correct behavior: the reloaded score matches the epoch-5 test score.

Here is the reference code to reproduce the issue (as posted it runs the correct case; uncomment the evaluate() call at epoch 10 to trigger the wrong loading behavior):

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 100)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from apex import amp
import time

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=256,
                                         shuffle=False, num_workers=4)

device = torch.device('cuda:0')

model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), eps=1e-4)
criterion = nn.CrossEntropyLoss()

model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

print(model)

# Evaluate function
def evaluate():
    model.eval()

    with torch.no_grad():
        running_loss = []
        for data in testloader:
            inputs = data[0].to(device)
            targets = data[-1].to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            running_loss.append(loss.item())

    model.train()
    return np.array(running_loss).mean()

# Save model, optimizer, and amp (loss-scaler) state
def save_model():
    checkpoint = dict(model=model.state_dict(), optimizer=optimizer.state_dict(), amp=amp.state_dict())
    torch.save(checkpoint, 'temp.pth')
    print('Model Saved')

# Load model, optimizer, and amp state from the checkpoint
def load_model():
    checkpoint = torch.load('temp.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    amp.load_state_dict(checkpoint['amp'])
    print('Model Loaded')

for i in range(100):
    running_loss = []
    for data in trainloader:
        inputs = data[0].to(device)
        targets = data[-1].to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        with amp.scale_loss(loss, optimizer) as scale_loss:
            scale_loss.backward()

        running_loss.append(loss.item())

        optimizer.step()
        optimizer.zero_grad()

    print('Loss at epoch ' + str(i + 1) + ' : ' + str(np.array(running_loss).mean()))

    # evaluate and save a checkpoint at epoch 5
    if (i + 1) == 5:
        print('Test Score at epoch ' + str(i + 1) + ' : ' + str(evaluate()))
        save_model()

    # load the checkpoint and break out after 10 epochs
    if (i + 1) == 10:
        # uncommenting the next line (evaluate() before load_model()) reproduces the bug
        #print('Test Score at epoch ' + str(i + 1) + ' : ' + str(evaluate()))
        load_model()
        print('Test Score at epoch ' + str(5) + ' loaded from checkpoint: ' + str(evaluate()))
        break
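
As a minimal check of the symptom (an added sketch, not part of the original script), one can snapshot a parameter tensor before load_model() and compare it afterwards:

before = model.conv1.weight.detach().clone()       # snapshot one parameter tensor
load_model()
unchanged = torch.equal(before, model.conv1.weight.detach())
print('weights unchanged after load:', unchanged)  # True here confirms the bad behavior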
Synicix commented 4 years ago

Note that this doesn't occur for O0, O2, or O3.
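
To compare opt levels, one hypothetical tweak is to take the level from the command line and run the script once per level (amp.initialize is intended to be called only once per process):

import sys

opt_level = sys.argv[1] if len(sys.argv) > 1 else 'O1'   # e.g. python repro.py O2
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)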