rishikksh20 / MLP-Mixer-pytorch

Unofficial implementation of MLP-Mixer: An all-MLP Architecture for Vision
MIT License

CIFAR training example #1

Closed: loretoparisi closed this issue 3 years ago

loretoparisi commented 3 years ago

Hello, thanks for this project! I'm trying to add training code for MLP-Mixer using the CIFAR-10 dataset. I have added a transform to adapt the images:

import os

import torch
import torchvision
import torchvision.transforms as T

from mlp_mixer import MLPMixer  # assumption: this repo's MLPMixer class, importable as mlp_mixer

# Upscale CIFAR's 32x32 images: resize to 256, then center-crop to 224
transform256 = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet statistics
        std=[0.229, 0.224, 0.225]
    )
])

# training set
training_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'training')

trainset = torchvision.datasets.CIFAR10(root=training_folder, train=True,
                                        download=True, transform=transform256)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
# test set (same transform as the training set)
testset = torchvision.datasets.CIFAR10(root=training_folder, train=False,
                                       download=True, transform=transform256)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)
# cifar classes
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# MLP Mixer
mixer_model = MLPMixer(in_channels=3,
                       image_size=224,
                       patch_size=16,
                       num_classes=1000,
                       dim=512,
                       depth=8,
                       token_dim=256,
                       channel_dim=2048)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mixer_model.to(device)

and this loop to check that the images pass through the model:

for i, data in enumerate(trainloader, 0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    print("Input:", inputs.shape)
    outputs = mixer_model(inputs)
    print("MLPMixer out:", outputs.shape)
    if i == 1:
        break

and I get

Input: torch.Size([4, 3, 224, 224])
MLPMixer out: torch.Size([4, 1000])

This is only naive training code, and I'm not sure whether resizing the input images this way is correct for the model. Thank you.
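Whether a resize is "correct" mostly comes down to the Mixer's patch arithmetic. Assuming the model splits each square image into non-overlapping patch_size x patch_size patches, as described in the MLP-Mixer paper, image_size must be divisible by patch_size, and the number of tokens is (image_size / patch_size)^2. The sketch below checks this with a hypothetical helper, `num_patches` (not part of this repo). The 224/16 values come from the post; the 32/4 pair is only an illustration of skipping the upscale, not a tested configuration.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping patches (tokens) for a square image."""
    if image_size % patch_size != 0:
        raise ValueError(
            f"image_size={image_size} is not divisible by patch_size={patch_size}"
        )
    per_side = image_size // patch_size
    return per_side * per_side


# Setup from the post: CIFAR images upscaled to 224x224, 16x16 patches.
print(num_patches(224, 16))  # 14 * 14 = 196 tokens

# Hypothetical alternative: keep CIFAR's native 32x32 resolution with 4x4 patches.
print(num_patches(32, 4))    # 8 * 8 = 64 tokens
```

Upscaling to 224 does not add information to a 32x32 image; it mainly lets the 224/16 configuration be reused unchanged, at the cost of more compute per image.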

Oktai15 commented 3 years ago

@loretoparisi you need to change num_classes=1000 to num_classes=10, since CIFAR-10 has 10 classes.
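The classifier head emits one logit per class. With num_classes=1000 the output has shape [batch, 1000], so an argmax can land on an index that has no CIFAR-10 class name. A minimal pure-Python sketch of that check, where `predict_names` is a hypothetical helper and plain lists stand in for the model's output tensor:

```python
CIFAR10_CLASSES = ('plane', 'car', 'bird', 'cat',
                   'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def predict_names(logits_batch, classes=CIFAR10_CLASSES):
    """Map each row of logits to a class name via argmax.

    Raises ValueError when the head size does not match the number of
    classes (the num_classes=1000 vs. 10 mismatch).
    """
    names = []
    for row in logits_batch:
        if len(row) != len(classes):
            raise ValueError(
                f"model outputs {len(row)} logits but there are {len(classes)} classes"
            )
        best = max(range(len(row)), key=row.__getitem__)  # argmax over the row
        names.append(classes[best])
    return names


# A head with num_classes=10: every argmax index maps to a CIFAR-10 name.
print(predict_names([[0.1] * 3 + [2.0] + [0.1] * 6]))  # ['cat']

# A head with num_classes=1000: rejected before indexing into `classes`.
try:
    predict_names([[0.0] * 1000])
except ValueError as err:
    print(err)  # model outputs 1000 logits but there are 10 classes
```

Training with `nn.CrossEntropyLoss` would not crash with 1000 outputs, because CIFAR-10 labels 0 to 9 are valid indices; the extra 990 logits are simply never a target, which is why the mismatch is easy to miss.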

loretoparisi commented 3 years ago

Thanks, it works now!