lukemelas / PyTorch-Pretrained-ViT

Vision Transformer (ViT) in PyTorch

Load from Google's saved weights. #9

Open anicolson opened 3 years ago

anicolson commented 3 years ago

Hi, I was wondering if there is a way to load the weights directly from Google's saved checkpoint, instead of having to download them.

I see that the init of ViT contains this call:

            load_pretrained_weights(
                self, name, 
                load_first_conv=(in_channels == pretrained_num_channels),
                load_fc=(num_classes == pretrained_num_classes),
                load_repr_layer=load_repr_layer,
                resize_positional_embedding=(image_size != pretrained_image_size),
            )

So the weights_path can't be passed to ViT. Could this be amended?
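
To make the request concrete, here is a rough sketch of how the flags in that call could decide which entries of a locally loaded checkpoint get used. Everything here is an assumption for illustration: `filter_state_dict` is a hypothetical helper, not part of this repo, and the key prefixes `patch_embedding.` and `fc.` are guesses at the layer names. A plain dict stands in for what `torch.load(weights_path, map_location="cpu")` would return.

```python
# Hypothetical helper, not part of PyTorch-Pretrained-ViT: decides which
# checkpoint entries to keep, mirroring two of the flags in the call above.
def filter_state_dict(state_dict, load_first_conv=True, load_fc=True,
                      first_conv_prefix="patch_embedding.", fc_prefix="fc."):
    """Return a copy of state_dict without the layers that should be skipped."""
    skipped = []
    if not load_first_conv:
        skipped.append(first_conv_prefix)  # assumed name of the first conv layer
    if not load_fc:
        skipped.append(fc_prefix)          # assumed name of the classifier head
    return {k: v for k, v in state_dict.items()
            if not any(k.startswith(p) for p in skipped)}


# Stand-in for a checkpoint read from a local file; the values would be tensors.
checkpoint = {
    "patch_embedding.weight": "conv-w",
    "patch_embedding.bias": "conv-b",
    "transformer.blocks.0.attn.proj_q.weight": "q-w",
    "fc.weight": "fc-w",
    "fc.bias": "fc-b",
}

# Example values: same input channels, different number of classes.
in_channels, pretrained_num_channels = 3, 3
num_classes, pretrained_num_classes = 10, 1000

kept = filter_state_dict(
    checkpoint,
    load_first_conv=(in_channels == pretrained_num_channels),
    load_fc=(num_classes == pretrained_num_classes),
)
print(sorted(kept))
```

With a real model, the filtered dict could then go to `model.load_state_dict(kept, strict=False)`, since the skipped layers would otherwise be reported as missing keys.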

Thank you for your help.

anicolson commented 3 years ago

Sorry, just realized that you don't use Google's available checkpoints.