mmasana / FACIL

Framework for Analysis of Class-Incremental Learning with 12 state-of-the-art methods and 3 baselines.
https://arxiv.org/pdf/2010.15277.pdf
MIT License

How to integrate ViT in Networks? #46

Open Snimm opened 4 months ago

Snimm commented 4 months ago

I want to use Hugging Face's ViTForImageClassification with the pretrained model 'google/vit-base-patch16-224'. How do I integrate it into FACIL? I have read the instructions for adding networks in the networks README, but I am still not sure how to implement it. How do I set self.head_var = 'fc' when the head is replaced with model.classifier = nn.Linear(768, num_classes)? And how exactly would the network class be created in this case?

mmasana commented 4 months ago

I'm not sure I understand what the exact issue is. If the pretrained model has a head called classifier, you just need to set self.head_var = 'classifier' and that head should be removed. Do not use self.head_var = 'fc', because there is no layer called fc. Alternatively, you can remove the head from the pretrained model yourself and pass the argument --keep-existing-head (even though the head is already gone) so that nothing else is removed.
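For illustration, here is a minimal sketch of how such a network class might look. It is not FACIL's own code. It assumes PyTorch and Hugging Face `transformers` are installed, and it uses `ViTModel` (the backbone without a head) instead of `ViTForImageClassification`, so that the wrapper owns a single `nn.Linear` head named `classifier` and sets `self.head_var = 'classifier'` to match. `ViTWrapper` and `vit_hf` are hypothetical names. Registering the factory in `networks/__init__.py`, and how your FACIL version calls it (it may always pass `pretrained=False` for non-torchvision networks), are also assumptions worth checking.

```python
import torch
from torch import nn
from transformers import ViTConfig, ViTModel


class ViTWrapper(nn.Module):
    """Hypothetical wrapper making a Hugging Face ViTModel look like a FACIL network:
    forward(x) returns a plain tensor, and the head is the attribute named by head_var."""

    def __init__(self, backbone, num_classes=1000):
        super().__init__()
        self.backbone = backbone
        # Final linear layer; its attribute name must match self.head_var
        self.classifier = nn.Linear(backbone.config.hidden_size, num_classes)
        self.head_var = 'classifier'

    def forward(self, x):
        # ViTModel takes pixel_values of shape (B, 3, H, W) and returns a ModelOutput
        hidden = self.backbone(pixel_values=x).last_hidden_state
        # Classify from the [CLS] token embedding, as ViTForImageClassification does
        return self.classifier(hidden[:, 0])


def vit_hf(pretrained=False, num_classes=1000, config=None):
    """Hypothetical factory in the style of FACIL's network constructors."""
    if pretrained:
        # Downloads the checkpoint; its original classification head is not loaded
        backbone = ViTModel.from_pretrained('google/vit-base-patch16-224',
                                            add_pooling_layer=False)
    else:
        backbone = ViTModel(config or ViTConfig(), add_pooling_layer=False)
    return ViTWrapper(backbone, num_classes)


if __name__ == '__main__':
    # Tiny random config so the demo runs without downloading any weights
    cfg = ViTConfig(image_size=32, patch_size=8, hidden_size=64,
                    num_hidden_layers=2, num_attention_heads=4,
                    intermediate_size=128)
    model = vit_hf(config=cfg, num_classes=10).eval()
    x = torch.randn(2, 3, 32, 32)
    print(model(x).shape)  # torch.Size([2, 10])
    # Mimic removing an nn.Linear head: swap it for an empty nn.Sequential,
    # so forward returns the 64-dim [CLS] features instead of logits
    setattr(model, model.head_var, nn.Sequential())
    print(model(x).shape)  # torch.Size([2, 64])
```

Because the head is a plain `nn.Linear` stored under the name given in `head_var`, removing it leaves a backbone that outputs `hidden_size` features (768 for vit-base), which new task heads can then be attached to. Keep in mind that vit-base-patch16-224 expects 224×224 inputs, while datasets such as CIFAR-100 are 32×32. The data transforms would therefore likely need a resize step.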