kuangliu / pytorch-cifar

95.47% on CIFAR10 with PyTorch
MIT License

Performance degradation when loading a model after saving it #162

Open Peter-SungwooCho opened 1 year ago

Peter-SungwooCho commented 1 year ago

After training ResNet-18, when I run inference with the saved model, the accuracy is only 10%. What could be the problem?

varunponda commented 1 year ago

Hi @Peter-SungwooCho, one reason the model may not show the desired accuracy during inference is that different hyperparameters (such as hidden_dim) were used during training than when loading the saved model. This is only a guess, but it would help if you could show the code where you load the saved model (assuming the problem occurs only when loading, and there is no overfitting).
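One quick way to check this kind of guess (a sketch added here, not from the thread) is to diff the checkpoint's key names against the model's before loading; any name that appears in only one of the two sets means the architecture or save format disagrees:

```python
def diff_state_dict_keys(model_keys, ckpt_keys):
    """Return (missing, unexpected): keys the model expects but the
    checkpoint lacks, and keys the checkpoint has but the model doesn't."""
    missing = sorted(set(model_keys) - set(ckpt_keys))
    unexpected = sorted(set(ckpt_keys) - set(model_keys))
    return missing, unexpected

# Hypothetical key names, for illustration only.
model_keys = ["conv1.weight", "linear.weight"]
ckpt_keys = ["module.conv1.weight", "module.linear.weight"]
print(diff_state_dict_keys(model_keys, ckpt_keys))
# (['conv1.weight', 'linear.weight'], ['module.conv1.weight', 'module.linear.weight'])
```

In PyTorch itself, `load_state_dict(..., strict=False)` already returns the same information as a named tuple with `missing_keys` and `unexpected_keys` fields, so printing its return value shows whether any weights were actually loaded.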

MatsuoTakuyaD commented 1 year ago

> Hi @Peter-SungwooCho, one reason the model may not show the desired accuracy during inference is that different hyperparameters (such as hidden_dim) were used during training than when loading the saved model. This is only a guess, but it would help if you could show the code where you load the saved model (assuming the problem occurs only when loading, and there is no overfitting).

Hello, I face the same problem. This is my code:

device = torch.device("cuda")

full_model = ResNet18().to(device)
#print(full_model)
checkpoint = torch.load("./checkpoint/ckpt.pth")
print(checkpoint)
full_model.load_state_dict(checkpoint['net'], strict=False)  # strict=False silently skips any keys that don't match
full_model.to(device)
full_model.eval()

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=transform_test, download=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

correct = 0
total = 0
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = full_model(inputs)

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(test_loader), 'Acc: %.3f%% (%d/%d)'
                     % (100.*correct/total, correct, total))
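A likely cause (an assumption; the thread does not confirm it): `main.py` in this repo wraps the network in `torch.nn.DataParallel` before saving, so every key in `checkpoint['net']` carries a `module.` prefix. Loading that state dict into a plain `ResNet18()` with `strict=False` silently skips every mismatched key, leaving the random initial weights and the ~10% chance-level accuracy on CIFAR-10. A minimal sketch of stripping the prefix before loading:

```python
def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel adds to every key."""
    return {key[len("module."):] if key.startswith("module.") else key: value
            for key, value in state_dict.items()}

# Hypothetical keys as they would appear in a DataParallel checkpoint.
ckpt = {"module.conv1.weight": 0, "module.linear.bias": 1}
print(strip_module_prefix(ckpt))
# {'conv1.weight': 0, 'linear.bias': 1}
```

With the cleaned dict, `full_model.load_state_dict(strip_module_prefix(checkpoint['net']), strict=True)` either loads every weight or raises an error, instead of silently loading nothing. Alternatively, wrap the inference model in `torch.nn.DataParallel` the same way training did before calling `load_state_dict`.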