xxxnell / how-do-vits-work

(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"
https://arxiv.org/abs/2202.06709
Apache License 2.0
798 stars, 77 forks

Code for Alter-ResNet-50 #1

Closed: DarrenIm closed this issue 2 years ago

DarrenIm commented 2 years ago

Hi, awesome work and really good points about MSAs! I'm very interested in the AlterNet mentioned in the paper (based on ResNet-50 and Swin Transformer blocks), but I can't find its implementation in this repo. Did I miss it? If not, could you release the code?

Thanks a lot!

xxxnell commented 2 years ago

Hi, Darren. Thank you for your question.

The AlterNet code is available at models/alternet.py (on the default transformer branch, not the master branch). The model uses a greatly simplified Swin Transformer block, and its default configuration targets ImageNet-1K. To use AlterNet on CIFAR-100, please refer to the snippet below:

```python
from functools import partial

import torch
import models
import models.alternet as alternet
import models.preresnet_dnn_block as preresnet_dnn

from models.alternet import AttentionBlockB, StemB

model = alternet.AlterNet(
    preresnet_dnn.Bottleneck, AttentionBlockB, stem=partial(StemB, pool=False),
    num_blocks=[3, 4, 6, 4], num_blocks2=[0, 1, 1, 2],
    num_classes=100, window_size=4, heads=[3, 6, 12, 24], sd=0.1,
    name="alternet_50",
)
models.stats(model)
```

Again, thank you for your interest in our paper. If you require any further information, feel free to contact me!
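AlterNet's attention blocks take a window_size argument, which suggests self-attention is computed inside non-overlapping local windows, as in Swin. The sketch below shows that window partitioning in PyTorch. It is modeled on the original Swin Transformer, assumes channels-last (B, H, W, C) feature maps whose height and width are divisible by window_size, and is not taken from this repo, whose simplified block may differ. The helper names window_partition and window_reverse and the 32x32x96 input are illustrative only.

```python
import torch


def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """Split a (B, H, W, C) feature map into non-overlapping windows.

    Returns (B * num_windows, window_size * window_size, C): one token
    sequence per window, which is what window attention operates on.
    """
    b, h, w, c = x.shape
    assert h % window_size == 0 and w % window_size == 0, "H and W must be divisible by window_size"
    # (B, H/ws, ws, W/ws, ws, C)
    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
    # Group the two window-grid axes, then flatten each window into a sequence
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(-1, window_size * window_size, c)


def window_reverse(windows: torch.Tensor, window_size: int, b: int, h: int, w: int) -> torch.Tensor:
    """Inverse of window_partition: reassemble windows into a (B, H, W, C) map."""
    c = windows.shape[-1]
    x = windows.view(b, h // window_size, w // window_size, window_size, window_size, c)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(b, h, w, c)


# Illustrative 32x32 map with 96 channels, split with window_size=4
feat = torch.randn(2, 32, 32, 96)
windows = window_partition(feat, window_size=4)
print(windows.shape)  # torch.Size([128, 16, 96]): 2 images x 64 windows, 16 tokens each
assert torch.equal(window_reverse(windows, 4, 2, 32, 32), feat)
```

Because the windows do not overlap, attention cost grows with the number of windows rather than quadratically with H x W; Swin's shifted windows, which let information cross window borders, are omitted here.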

luluenen commented 2 years ago

Hi, I am very interested in your work on AlterNet. I tested the code given above, using AlterNet for CIFAR-100, but I got an error:

```
File "./how-do-vits-work/models/alternet.py", line 39, in forward
    mask = mask + self.pos_embedding[self.rel_index[:, :, 0], self.rel_index[:, :, 1]]
IndexError: tensors used as indices must be long, byte or bool tensors
```

Is this expected?

Thanks a lot, Lu
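The error means the tensors used to index pos_embedding do not have an accepted integer (or boolean) dtype. A minimal sketch of the rule, using a hypothetical 3x3 table in place of pos_embedding: int64 indices work, floating-point indices raise IndexError. Older PyTorch releases, such as the one that printed the "long, byte or bool" message above, also rejected int32 indices, so an index tensor built from an int32 NumPy array is one plausible trigger; that is an assumption about this report, not a confirmed diagnosis.

```python
import torch

# Hypothetical 3x3 table standing in for pos_embedding
table = torch.arange(9.0).view(3, 3)
rows = torch.tensor([0, 2])
cols = torch.tensor([1, 0])

# int64 (torch.long) indices are accepted by every PyTorch version
picked = table[rows.long(), cols.long()]
print(picked)  # tensor([1., 6.])

# Floating-point index tensors are rejected with an IndexError
float_rejected = False
try:
    table[rows.float(), cols.float()]
except IndexError as err:
    float_rejected = True
    print("IndexError:", err)
```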

xxxnell commented 2 years ago

Hi,

I tested the code again on the pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime Docker image, and the snippet above worked. Please check your Python/PyTorch versions. If the issue persists after fixing the dependencies, please let me know.

Alternatively, this Colab notebook may be helpful.

Thank you for your support!
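A quick way to collect the version information requested above, plus the dtype that NumPy and PyTorch assign to a window's coordinate grid. The 4x4 grid mirrors window_size=4 from the CIFAR-100 snippet; this is only a diagnostic sketch. On platforms where NumPy's default integer is 32-bit (for example Windows with NumPy 1.x), the last line would be expected to show torch.int32 rather than torch.int64.

```python
import sys

import numpy as np
import torch

print("Python :", sys.version.split()[0])
print("PyTorch:", torch.__version__)
print("NumPy  :", np.__version__)

# dtype NumPy picks for a plain list of Python ints; platform-dependent in NumPy 1.x
coords = np.array([[x, y] for x in range(4) for y in range(4)])
print("NumPy coords dtype  :", coords.dtype)
# torch.tensor keeps the NumPy integer width unless dtype= is given
print("torch.tensor dtype  :", torch.tensor(coords).dtype)
```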

xxxnell commented 2 years ago

Closing this issue based on the comment above. Please feel free to reopen this issue if the problem still exists.

Lmy0217 commented 2 years ago

@xxxnell @luluenen I got the same error and fixed it by changing https://github.com/xxxnell/how-do-vits-work/blob/cea0635c77ca9a22882feb965995c0eac917ebdd/models/alternet.py#L49 to:

```python
i = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]), dtype=torch.long)
```
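Why the explicit dtype matters: the coordinates i are presumably turned into the relative position index that line 39 in the traceback uses to look up self.pos_embedding, so whatever dtype i has propagates into that lookup. The sketch below rebuilds such a lookup in the Swin style, assuming, as the two-index lookup suggests, that pos_embedding is a (2 * window_size - 1) x (2 * window_size - 1) table. The offset sign convention and the random table are assumptions, not the repo's code.

```python
import numpy as np
import torch

window_size = 4

# Coordinates of every token in a window, forced to int64 as in the fix above
i = torch.tensor(
    np.array([[x, y] for x in range(window_size) for y in range(window_size)]),
    dtype=torch.long,
)  # shape (window_size**2, 2)

# Pairwise relative offsets in [-(ws-1), ws-1], shifted to start at 0 so they can
# index a (2*ws-1) x (2*ws-1) table. The sign convention is an assumption.
rel_index = i[None, :, :] - i[:, None, :] + window_size - 1  # (ws**2, ws**2, 2)

# Hypothetical learnable bias table, one scalar per relative (dx, dy) offset
pos_embedding = torch.nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))

# Same gather as the failing line; it works because rel_index is int64
bias = pos_embedding[rel_index[:, :, 0], rel_index[:, :, 1]]
print(rel_index.dtype, bias.shape)  # torch.int64 torch.Size([16, 16])
```

Since subtraction and addition of Python ints preserve the integer width, fixing the dtype once when i is created is enough for every later index to be int64.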