xxxnell / how-do-vits-work

(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"
https://arxiv.org/abs/2202.06709
Apache License 2.0

AlterNet on CIFAR10 #38

Closed 23Uday closed 1 year ago

23Uday commented 1 year ago

Hi, while trying to set up alternet_18 to train on CIFAR10, I used the default config in models/alternet.py, which is the following:

AlterNet(preresnet_dnn.BasicBlock, AttentionBasicBlockB, stem=partial(StemB, pool=stem),
         num_blocks=(2, 2, 2, 2), num_blocks2=(0, 1, 1, 1), heads=(3, 6, 12, 24),
         num_classes=num_classes, name=name, **block_kwargs)

Upon doing so, I get the following error:

Input tensor shape: torch.Size([128, 128, 4, 4]). Additional info: {'p1': 7, 'p2': 7}. Shape mismatch, can't divide axis of length 4 in chunks of 7

It is thrown by x = rearrange(x, "b c (n1 p1) (n2 p2) -> (b n1 n2) c p1 p2", p1=p, p2=p) in the class LocalAttention. This happens because the default window size is 7, which does not evenly divide the 4 x 4 feature map that results from the 3 x 32 x 32 input images of CIFAR10. Could you point me to a setup used to train AlterNet on CIFAR10/100 images? Thank you.

xxxnell commented 1 year ago

Hi @23Uday,

Thank you for reaching out, and I apologize for the delayed response. I was out of the office.

For CIFAR, you can use a window size of 4. That is, you can use the following configuration:

block_kwargs = {  # for CIFAR
    "image_size": 32, 
    "patch_size": 2,
    "window_size": 4
}

model = AlterNet(..., **block_kwargs)
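As a quick sanity check of why window size 7 fails and 4 works, the divisibility constraint behind the rearrange pattern can be sketched in plain Python. The helper name window_fits and the list assumed_sizes are hypothetical and not part of the repository; the stage resolutions are an assumption (each stage halving the spatial size of a 32 x 32 input) and should be checked against the actual model.

```python
def window_fits(feature_size: int, window_size: int) -> bool:
    """Return True if a square feature map splits into non-overlapping windows."""
    return feature_size % window_size == 0


# 4 is the feature-map side length reported in the error message.
print(window_fits(4, 7))  # False: 4 is not divisible by the default window size 7
print(window_fits(4, 4))  # True: a window of 4 tiles the 4 x 4 map exactly

# Assumed (not verified) feature-map sizes at the attention stages for a
# 32 x 32 input, if each stage halves the resolution.
assumed_sizes = [16, 8, 4]
print(all(window_fits(s, 4) for s in assumed_sizes))  # True
```

Under these assumptions, any window size that divides every attention-stage feature map would pass this check; 4 is the value suggested above.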