lukemelas / PyTorch-Pretrained-ViT

Vision Transformer (ViT) in PyTorch

cannot download pre-trained model #3

Closed: muaz1994 closed this issue 3 years ago

muaz1994 commented 3 years ago

Hi!

When I try to download the pre-trained model, I get the error: 'NoneType' object has no attribute 'group'

This is what I do:

from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)

I'm using PyTorch 1.1. Does the version matter? It works on Google Colab but not on my machine.

fawazsammani commented 3 years ago

Hi @muaz1994. I had a similar issue a few days ago, and it appears to be related to torch hub's load-from-URL function. To work around it, I downloaded the pre-trained model manually and loaded it:

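import torch
from pytorch_pretrained_vit import ViT
model = ViT('L_16_imagenet1k')  # build the architecture; weights are loaded below
checkpoint = torch.load('L_16_imagenet1k.pth')  # the manually downloaded file
model.load_state_dict(checkpoint)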
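
A sketch of why the hub loader could fail with exactly this error, under an assumption not confirmed in this thread: that PyTorch 1.1's torch.hub.load_state_dict_from_url pulls an expected hash prefix out of the checkpoint filename with a regex and calls .group(1) on the result without checking for a match. A filename such as B_16_imagenet1k.pth has no "-<hash>" suffix, so the search returns None. Newer releases make hash checking opt-in through a check_hash argument, which would also fit the report that the same call works on Colab. The regex and both helper names below are illustrative, not copied from PyTorch.

```python
import os
import re
from urllib.parse import urlparse

# Assumption: older torch.hub releases matched a pattern like this against the
# checkpoint filename, e.g. "resnet18-5c106cde.pth" -> "5c106cde".
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")


def hash_prefix_unchecked(url: str) -> str:
    """Hypothetical helper mimicking the assumed old behaviour: no None check."""
    filename = os.path.basename(urlparse(url).path)
    return HASH_REGEX.search(filename).group(1)


def hash_prefix_checked(url: str):
    """Hypothetical guarded variant: returns None when the filename has no hash."""
    filename = os.path.basename(urlparse(url).path)
    match = HASH_REGEX.search(filename)
    return match.group(1) if match else None
```

With a placeholder URL such as "https://example.com/B_16_imagenet1k.pth", hash_prefix_checked returns None, while hash_prefix_unchecked raises AttributeError: 'NoneType' object has no attribute 'group', the same message reported above.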

The download links for the models can be found in configs.py.
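
A minimal standard-library sketch of the manual download step, assuming you copy the checkpoint URL for your model from configs.py (no URL is reproduced here). download_checkpoint is a hypothetical helper, not part of pytorch_pretrained_vit. It skips the download when the file already exists, and it writes to a temporary .part file first so an interrupted download does not leave a truncated checkpoint behind.

```python
import os
import shutil
import urllib.request
from urllib.parse import urlparse


def download_checkpoint(url: str, dest_dir: str = ".") -> str:
    """Hypothetical helper: fetch url into dest_dir and return the local path."""
    filename = os.path.basename(urlparse(url).path)
    dest = os.path.join(dest_dir, filename)
    if os.path.exists(dest):
        return dest  # already downloaded; reuse it
    os.makedirs(dest_dir, exist_ok=True)
    tmp = dest + ".part"
    with urllib.request.urlopen(url) as resp, open(tmp, "wb") as out:
        shutil.copyfileobj(resp, out)  # stream to disk without loading it all in memory
    os.replace(tmp, dest)  # expose the file only once it is complete
    return dest
```

The returned path can then be passed to torch.load, as in the snippet above; adding map_location='cpu' loads the tensors onto the CPU regardless of the device they were saved from.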

If you're using a PyTorch version older than 1.5, you also have to replace F.gelu in the PositionWiseFeedForward class in transformer.py with gelu(x), where gelu is defined below:

import math
import torch
def gelu(x):  # tanh approximation of GELU
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
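
For context, a sketch (not from the thread) comparing the tanh formula above with the exact GELU, x·Φ(x) = 0.5·x·(1 + erf(x/√2)), which is what F.gelu computes by default in current PyTorch releases. It uses plain Python floats and math.erf, so it needs neither torch nor the library; both function names are illustrative. The two curves agree closely, so the substitution should change the pretrained model's outputs only slightly.

```python
import math


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Tanh approximation: the same formula as the torch-based gelu() above."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


if __name__ == "__main__":
    # Largest absolute gap between the two on a grid over [-6, 6].
    grid = [i / 100 for i in range(-600, 601)]
    worst = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in grid)
    print(f"max |exact - tanh| on [-6, 6]: {worst:.2e}")
```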

muaz1994 commented 3 years ago

Thanks @fawazsammani. This worked.