Open Peter-SungwooCho opened 1 year ago
Hi @Peter-SungwooCho, one possible reason the model does not reach the expected accuracy at inference is that different hyperparameters (such as hidden_dim) were used during training and when loading the saved model. This is only a guess, so it would help if you could share the code you use to load the saved model. (I am assuming the problem occurs only when loading the saved model and that there is no overfitting.)
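A hyperparameter mismatch like the one suggested above usually shows up as parameter names or shapes that differ between the freshly built model and the checkpoint. The following is a minimal sketch of such a check, not code from this thread; `compare_state_dicts` is a hypothetical helper, and it only assumes that both arguments map parameter names to objects with a `.shape` attribute, as `model.state_dict()` does in PyTorch.

```python
def compare_state_dicts(model_state, ckpt_state):
    """Report how a model's state dict differs from a checkpoint's.

    Both arguments are mappings from parameter name to a tensor-like
    object exposing `.shape` (hypothetical helper, for illustration).
    """
    model_keys, ckpt_keys = set(model_state), set(ckpt_state)

    # Parameters the model expects but the checkpoint does not provide.
    missing = sorted(model_keys - ckpt_keys)
    # Parameters stored in the checkpoint that the model does not have.
    unexpected = sorted(ckpt_keys - model_keys)

    # Parameters present on both sides whose shapes disagree, e.g. because
    # a different hidden_dim was used when the model was rebuilt.
    shape_mismatch = {}
    for name in sorted(model_keys & ckpt_keys):
        model_shape = tuple(model_state[name].shape)
        ckpt_shape = tuple(ckpt_state[name].shape)
        if model_shape != ckpt_shape:
            shape_mismatch[name] = (model_shape, ckpt_shape)

    return missing, unexpected, shape_mismatch
```

Called as `compare_state_dicts(model.state_dict(), checkpoint_state)`, three empty results would suggest the architecture matches the checkpoint and the cause lies elsewhere.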
Hello, I am facing the same problem. Here is my code:
import torch
import torchvision
import torchvision.transforms as transforms

# ResNet18 and progress_bar are defined elsewhere in my project (not shown here).

device = torch.device("cuda")
full_model = ResNet18().to(device)
# print(full_model)

# Load the checkpoint saved during training and copy its weights into the model.
checkpoint = torch.load("./checkpoint/ckpt.pth")
print(checkpoint)
full_model.load_state_dict(checkpoint['net'], strict=False)  # ,strict=False
full_model.to(device)
full_model.eval()

# CIFAR-10 test set with per-channel normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=transform_test, download=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

# Accumulate top-1 accuracy over the whole test set.
correct = 0
total = 0
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = full_model(inputs)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        progress_bar(batch_idx, len(test_loader), 'Acc: %.3f%% (%d/%d)'
                     % (100.*correct/total, correct, total))
After training ResNet-18, the accuracy at inference is only 10%, which is chance level for the ten CIFAR-10 classes. What could be the problem?
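One thing that may be worth checking, offered as a guess rather than a confirmed diagnosis: with `strict=False`, `load_state_dict` skips any checkpoint key whose name does not match a model parameter instead of raising an error, so the model can end up keeping its random initial weights. This happens, for example, when the checkpoint was saved from a model wrapped in `torch.nn.DataParallel`, which prefixes every key with `module.`. Whether that applies to this checkpoint is an assumption. The sketch below uses a hypothetical helper, `strip_prefix`, that renames such keys; it works on any dictionary and does not itself require PyTorch.

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from matching keys.

    Hypothetical helper for illustration; keys without the prefix are kept
    unchanged, and the values are not copied or modified.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Possible usage with the code above (assumes the checkpoint keys carry
# the "module." prefix; left as comments because it needs the real files):
#
#   state = strip_prefix(checkpoint['net'])
#   full_model.load_state_dict(state, strict=True)  # raises on any mismatch
```

Loading with `strict=True` is the safer default here because a naming mismatch then fails loudly. Alternatively, the object returned by `load_state_dict(..., strict=False)` exposes `missing_keys` and `unexpected_keys`, and printing those shows whether any weights were skipped.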