microsoft / Cream

This is a collection of our NAS and Vision Transformer work.

Running TinyViT - low accuracy #147

Closed. cairus closed this issue 1 year ago.

cairus commented 1 year ago

Hi,

I am interested in using TinyViT for a downstream task because of its good performance and small size. However, I am unable to verify that the model works as intended. I am running the model on a laptop without a GPU, on 1000 ImageNet 2011 validation images, and get 0.10% accuracy, which is what I'd expect from random guessing.

Here's the gist of the code I am running:

import os

import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms

import tiny_vit

labels = ...  # ImageNet 2011 validation dataset labels, integers from 1 to 1000
dataset_path = ...  # path to the ImageNet 2011 validation dataset

classifier = tiny_vit.tiny_vit_21m_224(pretrained=True)
# scale=(1.0, 1.0) and ratio=(1.0, 1.0) make the crop deterministic;
# necessary because the image sizes vary
cropper = torchvision.transforms.RandomResizedCrop(224, scale=(1.0, 1.0), ratio=(1.0, 1.0))
transform = transforms.Compose([transforms.ToTensor()])

preds = []
with torch.no_grad():
    for i, image_filename in enumerate(os.listdir(dataset_path)):
        if i >= 1000:
            break

        image = Image.open(os.path.join(dataset_path, image_filename))
        input_array = np.array(cropper(image))  # resulting array has shape (224, 224, 3)
        tensor = transform(input_array).reshape((1, 3, 224, 224))
        output = classifier(tensor).detach().numpy()
        prediction = np.argmax(output[0]) + 1  # shift to 1-based labels
        preds.append(prediction)

accuracy = ...  # compare preds and labels

I am unable to identify the mistake. TinyViT's 1000 output classes are [1 kit fox, 2 English setter, 3 Australian terrier, ...], offset by 1, right? I have manually checked the first 5 images and labels in my dataset and they match, so there is no mistake there. I did not see any mention that model.eval() is required for TinyViT; I have tried both with and without it, and both times the accuracy is around 0. I have verified that the register_tiny_vit_model() function is called, so the weights should be downloaded and loaded. I have also verified that the image cropping works as intended.

The model expects 4D input, so reshaping to (1, 3, 224, 224) should be correct, no?
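For reference, here is the minimal check I ran (assuming a contiguous tensor, which ToTensor produces) showing that the reshape is equivalent to adding a batch dimension with unsqueeze(0):

import torch

x = torch.randn(3, 224, 224)
# for a contiguous (3, 224, 224) tensor, both produce the same (1, 3, 224, 224) batch
assert torch.equal(x.reshape((1, 3, 224, 224)), x.unsqueeze(0))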

I would appreciate any kind of help!

Best, cairus

wkcn commented 1 year ago

Thanks for your attention to our work!

You can try evaluating the model on CPU with the following command:

python -m torch.distributed.launch --nproc_per_node 1 main.py --cfg configs/22kto1k/tiny_vit_21m_22kto1k.yaml --data-path ./ImageNet --batch-size 128 --eval --resume ./checkpoints/tiny_vit_21m_22kto1k_distill.pth --only-cpu

Here is the file synset_words.txt, which includes the 1,000 classification names of ImageNet-1k.
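For example, a minimal sketch of mapping a prediction index to a class name, assuming the usual synset_words.txt format (each line is a WordNet ID followed by the class name) and that the line order matches the model's 0-based output indices; classifier and tensor are taken from the snippet in the question:

with open("synset_words.txt") as f:
    # each line: WordNet ID, then the human-readable class name
    class_names = [line.strip().split(" ", 1)[1] for line in f]

logits = classifier(tensor)           # shape (1, 1000)
pred_index = int(logits[0].argmax())  # 0-based index into class_names
print(pred_index, class_names[pred_index])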

TinyViT should be evaluated in the model.eval() mode, since it contains BatchNorm layers.
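In the snippet from the question, that means calling eval() on the model before inference:

classifier = tiny_vit.tiny_vit_21m_224(pretrained=True)
classifier.eval()  # switch BatchNorm to its running statistics instead of per-batch ones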

Data augmentation for TinyViT

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
from torchvision.transforms import InterpolationMode

IMG_SIZE = 224
interpolation = InterpolationMode.BICUBIC
mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

t = []
size = int((256 / 224) * IMG_SIZE)
t.append(transforms.Resize(size, interpolation=interpolation))
t.append(transforms.CenterCrop(IMG_SIZE))
t.append(transforms.ToTensor())
t.append(transforms.Normalize(mean, std))
transform = transforms.Compose(t)
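Putting it together, a minimal single-image inference sketch using this transform (model as above; the filename is just a placeholder):

from PIL import Image

classifier = tiny_vit.tiny_vit_21m_224(pretrained=True)
classifier.eval()  # required because of the BatchNorm layers

image = Image.open("example.jpg").convert("RGB")  # convert() guards against grayscale images
tensor = transform(image).unsqueeze(0)            # (1, 3, 224, 224)

with torch.no_grad():
    logits = classifier(tensor)
pred_index = int(logits[0].argmax())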

cairus commented 1 year ago

Thank you for the synset_words.txt, I had used another list. Thank you also for the model.eval() suggestion. The model now reaches the reported accuracy. Solved!