DarrenIm closed this issue 2 years ago
Hi, Darren. Thank you for your question.
The AlterNet code is available at models > alternet.py (in the default transformer branch, not the master branch). The model uses a very simplified Swin Transformer block, and the default configuration is for ImageNet-1K. If you would like to use AlterNet for CIFAR-100, please refer to the snippet below:
from functools import partial
import torch
import models
import models.alternet as alternet
import models.preresnet_dnn_block as preresnet_dnn
from models.alternet import AttentionBlockB, StemB
model = alternet.AlterNet(
    preresnet_dnn.Bottleneck, AttentionBlockB, stem=partial(StemB, pool=False),
    num_blocks=[3, 4, 6, 4], num_blocks2=[0, 1, 1, 2],
    num_classes=100, window_size=4, heads=[3, 6, 12, 24], sd=0.1,
    name="alternet_50",
)
models.stats(model)
Again, thank you for your interest in our paper. If you require any further information, feel free to contact me!
Hi, I am very interested in your work on AlterNet. I tested the code given above, using AlterNet for CIFAR-100, but I got an error:
File "./how-do-vits-work/models/alternet.py", line 39, in forward
    mask = mask + self.pos_embedding[self.rel_index[:, :, 0], self.rel_index[:, :, 1]]
IndexError: tensors used as indices must be long, byte or bool tensors
Is this expected?
Thanks a lot, Lu
Hi,
I tested the code again on the pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime Docker image, and the snippet above worked. Please check your Python/PyTorch version. If the issue persists after you have sorted out the dependencies, please let me know.
Or, this Colab notebook may be helpful.
Thank you for your support!
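For anyone comparing environments, here is a small sketch that prints the versions and integer dtypes relevant to the IndexError above. It rests on an assumption that is not confirmed in this thread: the index tensor may end up as int32 on platforms where NumPy's default integer is 32-bit (for example Windows with NumPy 1.x). The helper name environment_report is illustrative only.

```python
import platform
import sys

import numpy as np
import torch


def environment_report() -> dict:
    """Collect the version and dtype facts that matter for tensor indexing."""
    # dtype NumPy chooses for plain Python ints; this is platform dependent.
    numpy_ints = np.array([0, 1])
    return {
        "python": sys.version.split()[0],
        "platform": platform.system(),
        "torch": torch.__version__,
        "numpy": np.__version__,
        "numpy_default_int": str(numpy_ints.dtype),
        # torch.tensor keeps the NumPy dtype, so int32 in means int32 out.
        "torch_from_numpy_dtype": str(torch.tensor(numpy_ints).dtype),
    }


for key, value in environment_report().items():
    print(f"{key}: {value}")
```

If torch_from_numpy_dtype is reported as torch.int32, an explicit dtype=torch.long where the index tensor is created is one way to avoid depending on the platform default.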
Closing this issue based on the comment above. Please feel free to reopen this issue if the problem still exists.
@xxxnell @luluenen I got the same error, and fixed it by modifying https://github.com/xxxnell/how-do-vits-work/blob/cea0635c77ca9a22882feb965995c0eac917ebdd/models/alternet.py#L49 to
i = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]), dtype=torch.long)
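A minimal sketch of why the explicit dtype matters. It assumes that rel_index is built from pairwise differences of window coordinates and then shifted to be non-negative, which is what the indexing in the traceback suggests but is not shown in this thread. The function relative_index and the standalone pos_embedding are illustrative, not the repository's code. The coordinates are built from a Python list here, so NumPy is not needed.

```python
import torch


def relative_index(window_size: int) -> torch.Tensor:
    """Relative (row, col) offsets between all window positions, as int64."""
    coords = torch.tensor(
        [[x, y] for x in range(window_size) for y in range(window_size)],
        dtype=torch.long,  # explicit, so the result never depends on platform defaults
    )
    # Pairwise differences lie in [-(w - 1), w - 1]; shift them to [0, 2w - 2].
    offsets = coords[None, :, :] - coords[:, None, :]
    return offsets + window_size - 1


window_size = 4
rel_index = relative_index(window_size)

# One learnable bias per possible (row, col) offset.
pos_embedding = torch.randn(2 * window_size - 1, 2 * window_size - 1)

# Indexing with int64 tensors works on every PyTorch version.
bias = pos_embedding[rel_index[:, :, 0], rel_index[:, :, 1]]
print(rel_index.dtype, tuple(bias.shape))
```

With window_size=4 the bias has one entry per pair of the 16 window positions, so it can be added to a 16 x 16 attention mask.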
Hi, awesome work and really good points about MSAs! I'm very interested in the AlterNet mentioned in the paper (based on ResNet-50 and SwinTBlock), but I can't find the implementation of it in this repo. Did I miss it? If not, could you release the code?
Thanks a lot!