lucidrains / vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

MAE using pretrained VIT #196

Open Songloading opened 2 years ago

Songloading commented 2 years ago

Hi There,

I am currently trying to fine-tune an MAE based on a pretrained ViT from timm. However, when I do:

import timm
import torch.nn as nn
from vit_pytorch.mae import MAE

# load a pretrained ViT from timm and swap in a 2-class head
v = timm.create_model('vit_base_patch16_224', pretrained=True)
num_ftrs = v.head.in_features
v.head = nn.Linear(num_ftrs, 2)

model = MAE(
    encoder = v,
    masking_ratio = 0.75,   # the paper recommended 75% masked patches
    decoder_dim = 512,      # paper showed good results with just 512
    decoder_depth = 6       # anywhere from 1 to 8
)

I got "AttributeError: 'VisionTransformer' object has no attribute 'pos_embedding'" It seems that timm model is not compatible with the MAE implementation. Can this be easily fixed or I will have to change the internal implementation of MAE?

lucidrains commented 2 years ago

@Songloading i have no idea! i'm not really familiar with timm - perhaps you can ask Ross about it?

Songloading commented 2 years ago

@lucidrains OK. Any ideas on other pretrained ViTs I could use besides timm's, or on a pretrained MAE?

mw9385 commented 1 year ago

Did you solve the issue? @Songloading