
Vision Transformer from Scratch in PyTorch

Simplified from-scratch PyTorch implementation of the Vision Transformer (ViT) with detailed steps (refer to model.py).

This repo uses a scaled-down version of the original ViT with a smaller patch size, tested on small datasets such as MNIST, FashionMNIST, SVHN, and CIFAR10.
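
To illustrate what the smaller patch size means in practice, here is a minimal patch-embedding sketch (hypothetical code, not the repo's model.py): a 1 × 28 × 28 MNIST image split into 4 × 4 patches yields 7 × 7 = 49 tokens, each projected to the embedding dimension.

```python
import torch
import torch.nn as nn

# Hypothetical sketch (not the repo's model.py): turn an image into a sequence
# of patch embeddings. A 1x28x28 image with patch_size=4 gives 7*7 = 49 tokens.
class PatchEmbed(nn.Module):
    def __init__(self, n_channels=1, patch_size=4, embed_dim=64):
        super().__init__()
        # A conv with kernel = stride = patch_size is equivalent to slicing
        # non-overlapping patches and applying a shared linear projection.
        self.proj = nn.Conv2d(n_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                    # x: (B, C, H, W)
        x = self.proj(x)                     # (B, embed_dim, H/patch, W/patch)
        return x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)

tokens = PatchEmbed()(torch.randn(8, 1, 28, 28))
print(tokens.shape)  # torch.Size([8, 49, 64])
```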

Run commands (also available in scripts.sh):

| Dataset | Run command | Test Acc (%) |
| --- | --- | --- |
| MNIST | `python main.py --dataset mnist --epochs 100` | 99.5 |
| Fashion MNIST | `python main.py --dataset fmnist` | 92.3 |
| SVHN | `python main.py --dataset svhn --n_channels 3 --image_size 32 --embed_dim 128` | 96.2 |
| CIFAR10 | `python main.py --dataset cifar10 --n_channels 3 --image_size 32 --embed_dim 128` | 86.3 (82.5 w/o RandAug) |
| CIFAR100 | `python main.py --dataset cifar100 --n_channels 3 --image_size 32 --embed_dim 128` | 59.6 (55.8 w/o RandAug) |



Transformer Config:

| Config | MNIST and FMNIST | SVHN and CIFAR |
| --- | --- | --- |
| Input Size | 1 × 28 × 28 | 3 × 32 × 32 |
| Patch Size | 4 | 4 |
| Sequence Length | 7 × 7 = 49 | 8 × 8 = 64 |
| Embedding Size | 64 | 128 |
| Parameters | 210k | 820k |
| Num of Layers | 6 | 6 |
| Num of Heads | 4 | 4 |
| Forward Multiplier | 2 | 2 |
| Dropout | 0.1 | 0.1 |
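
As a rough illustration of the MNIST/FMNIST column, the sketch below wires the same hyperparameters into `torch.nn.TransformerEncoder`. This is not the repo's model.py (which implements the blocks from scratch), so parameter counts will differ slightly.

```python
import torch
import torch.nn as nn

# Minimal sketch of the MNIST/FMNIST configuration above, built on
# nn.TransformerEncoder instead of the repo's own blocks in model.py.
embed_dim, n_layers, n_heads, fwd_mult, dropout = 64, 6, 4, 2, 0.1
seq_len, n_classes = 7 * 7, 10

encoder_layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=n_heads,
    dim_feedforward=fwd_mult * embed_dim,  # "Forward Multiplier" = 2
    dropout=dropout,
    batch_first=True,
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embed = nn.Parameter(torch.zeros(1, seq_len + 1, embed_dim))
head = nn.Linear(embed_dim, n_classes)

patches = torch.randn(8, seq_len, embed_dim)              # output of the patch embedding
x = torch.cat([cls_token.expand(8, -1, -1), patches], dim=1) + pos_embed
logits = head(encoder(x)[:, 0])                           # classify from the CLS token
print(logits.shape)  # torch.Size([8, 10])
```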