openai / CLIP

CLIP (Contrastive Language-Image Pretraining): predict the most relevant text snippet given an image
MIT License

Evaluation of CLIP on MNIST gives only 10% accuracy #412

Open GitOutOfMyBed opened 11 months ago

GitOutOfMyBed commented 11 months ago

According to the paper, zero-shot CLIP achieves 88% accuracy on MNIST. However, when I tested CLIP on MNIST I got only 10-20%. Does anyone know what I am doing wrong here?

import torch
import clip
import torchvision

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Zero-shot evaluation only needs the test split. CLIP's `preprocess` resizes,
# center-crops, converts the grayscale digits to RGB and normalizes them.
validation_set = torchvision.datasets.MNIST('./data', train=False, transform=preprocess, download=True)
testloader = torch.utils.data.DataLoader(validation_set, batch_size=32,
                                         shuffle=False, num_workers=0)

# Label sets that were tried; position k must describe digit k.
classes_mnist = [f'handwritten digit {d}' for d in range(10)]
classes_digits = [str(d) for d in range(10)]
classes = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
text = clip.tokenize(classes).to(device)

correct = 0
with torch.no_grad():
    for images, labels in testloader:
        # The model lives on `device`, so the image batch must be moved there too.
        images = images.to(device)
        logits_per_image, logits_per_text = model(images, text)
        # argmax over the logits gives the same prediction as argmax over softmax probabilities.
        preds = logits_per_image.argmax(dim=-1).cpu()
        correct += (preds == labels).sum().item()
print(correct / len(testloader.dataset))
VimukthiRandika1997 commented 11 months ago

Change the line text = clip.tokenize(classes).to(device) to something like this:

text = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classes]).to(device)

The results in the paper depend on how they set up the prompts and preprocessing, and I am not sure exactly what they used.
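Because the accuracy depends heavily on the prompt wording, it can help to encode the test images once and then score several templates (for example "a photo of a {}" versus "a photo of the number {}") against the cached image embeddings. The following is a minimal NumPy sketch. It assumes the image embeddings were already computed (for example with model.encode_image) and that encode_texts is a hypothetical helper mapping a list of strings to an (n, d) array of text embeddings:

```python
import numpy as np


def l2_normalize(x):
    """Scale each row (last axis) of x to unit L2 norm."""
    x = np.asarray(x, dtype=np.float64)
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def template_accuracy(image_embeddings, labels, encode_texts, classnames, template):
    """Top-1 zero-shot accuracy of one prompt template on cached image embeddings.

    encode_texts is a hypothetical helper: list of strings -> (n, d) array.
    """
    prompts = [template.format(c) for c in classnames]   # e.g. "a photo of a zero"
    text = l2_normalize(encode_texts(prompts))           # (num_classes, d)
    sims = l2_normalize(image_embeddings) @ text.T       # cosine similarities, (n, num_classes)
    preds = sims.argmax(axis=1)
    return float((preds == np.asarray(labels)).mean())


def sweep_templates(image_embeddings, labels, encode_texts, classnames, templates):
    """Map each template string to its accuracy; the images are encoded only once."""
    return {t: template_accuracy(image_embeddings, labels, encode_texts, classnames, t)
            for t in templates}
```

CLIP's logits are these cosine similarities multiplied by a positive logit_scale, so taking the argmax of plain cosine similarity should give the same predictions as logits_per_image, up to numerical precision.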

GitOutOfMyBed commented 11 months ago

I tried that and got better results but they were still nowhere near 88%.

mgupta70 commented 10 months ago

Performance varies a lot depending on the prompt template and the model you choose. Top-1 accuracy (%) on MNIST with the single template "a photo of the number {}", compared with the 80 ImageNet prompt templates:

Model      "a photo of the number {}"   80 ImageNet templates
RN50       51.28                        26.5
RN101      44.34                        36.14
RN50x4     56.64                        59.53
ViT-B/32   40.11                        30.16
ViT-B/16   53.95                        44.22

The 80-template set helps only RN50x4. For every other model, the single template scores higher.
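Using the 80 templates usually means prompt ensembling, as I understand the approach in CLIP's ImageNet prompt-engineering notebook. Each class name is inserted into every template, and the resulting text embeddings are L2-normalized. They are then averaged and the average is renormalized, giving one classifier weight per class. The following is a minimal NumPy sketch of that step. It assumes a hypothetical encode_texts helper (for instance a wrapper around clip.tokenize and model.encode_text) that returns an (n, d) array:

```python
import numpy as np


def l2_normalize(x):
    """Scale each row (last axis) of x to unit L2 norm."""
    x = np.asarray(x, dtype=np.float64)
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def build_class_weights(encode_texts, classnames, templates):
    """Return a (num_classes, d) matrix with one ensembled text embedding per class.

    encode_texts is a hypothetical helper: list of strings -> (n, d) array,
    e.g. a wrapper around clip.tokenize + model.encode_text.
    """
    weights = []
    for name in classnames:
        prompts = [t.format(name) for t in templates]   # one prompt per template
        emb = l2_normalize(encode_texts(prompts))        # normalize each prompt embedding
        weights.append(l2_normalize(emb.mean(axis=0)))   # average, then renormalize
    return np.stack(weights)


def predict(image_embeddings, class_weights):
    """Index of the class whose weight has the highest cosine similarity to each image."""
    sims = l2_normalize(image_embeddings) @ class_weights.T
    return sims.argmax(axis=1)
```

With a single template, this reduces to the ordinary zero-shot classifier. The same code can therefore compare both settings reported above on identical image embeddings.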