aaron-xichen / pytorch-playground

Base pretrained models and datasets in pytorch (MNIST, SVHN, CIFAR10, CIFAR100, STL10, AlexNet, VGG16, VGG19, ResNet, Inception, SqueezeNet)
MIT License
2.62k stars 612 forks source link

why is the prediction not correct? #64

Open jS5t3r opened 11 months ago

jS5t3r commented 11 months ago

It always predicts 464 for every sample...

import torch
import pickle as pkl
import time
import numpy as np
import cv2 
import matplotlib.pyplot as plt
import torchvision.models as models
import torchvision.transforms as transforms

def str2img(str_b):
    return cv2.imdecode(np.fromstring(str_b, np.uint8), cv2.IMREAD_COLOR)

def load_pickle(path):
    begin_st = time.time()
    with open(path, 'rb') as f:
        print("Loading pickle object from {}".format(path))
        v = pkl.load(f)
    print("=> Done ({:.4f} s)".format(time.time() - begin_st))
    return v

d = load_pickle('val224_compressed.pkl')

img224 = 0
target224 = 0
for img, target in zip(d['data'], d['target']):
    img224 = str2img(img)
    target224 = target
    break

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

img_tensor = transforms.ToTensor()(img224) / 255.
normalized_image = normalize(img_tensor)

model = models.resnet18(pretrained=True).eval()

pred = model(normalized_image.unsqueeze(0))

print(pred.argmax(1), target224)