facebookresearch / deit

Official DeiT repository
Apache License 2.0

No learning when transfer learning with Cait XXS24 224 #93

Closed BasileR closed 3 years ago

BasileR commented 3 years ago

Hello,

Thanks a lot for this great repo. I'm currently doing transfer learning with CaiT XXS24 224, but I have a problem when loading the pretrained weights: when I train CaiT on the new task, the accuracy starts at 10% (random) and never increases. I tried to train a small DeiT on this task with transfer learning, and that time it worked well (with the same training functions). Do you have any idea what the problem could be here?

Here is the code to load the weights (it is actually the one that you provide):

v = cait_XXS_224(pretrained=False)
checkpoint = torch.load('logs/ImageNet/XXS24_24.pth')
# the checkpoint was saved from a DataParallel model, so every key carries a 'module.' prefix
checkpoint_no_module = {}
for k in v.state_dict().keys():
    checkpoint_no_module[k] = checkpoint["model"]['module.' + k]
v.load_state_dict(checkpoint_no_module)
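For completeness, an equivalent but slightly more defensive way to strip the 'module.' prefix would be something like this (just a sketch, not code from the repo; it reuses `v` and the checkpoint path from above):

import torch
from collections import OrderedDict

checkpoint = torch.load('logs/ImageNet/XXS24_24.pth', map_location='cpu')
# drop the 'module.' prefix that nn.DataParallel adds to every key, if present
cleaned = OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, w)
    for k, w in checkpoint["model"].items()
)
missing, unexpected = v.load_state_dict(cleaned, strict=False)
print('missing keys:', missing)        # both lists should be empty
print('unexpected keys:', unexpected)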

I'm using torch 1.7.1 and timm 0.4.5.

reallm commented 3 years ago

I have a similar problem. I don't know how to resolve it.

TouvronHugo commented 3 years ago

Hi @BasileR and @reallm,

Thanks for your question. This is a strange problem, as it works well for DeiT. Did you check whether you get the right performance on ImageNet, to see if there is any problem when loading the weights?

Best,
Hugo

BasileR commented 3 years ago

Hi @TouvronHugo ,

Thank you for your answer. I tested cait_XXS_224 on ImageNet and got 78% accuracy, so I don't think there is a problem with loading the weights. Here is my code:

v = cait_XXS_224(pretrained=False)
checkpoint = torch.load('logs/ImageNet/XXS24_24.pth')
checkpoint_no_module = {}
for k in v.state_dict().keys():
    checkpoint_no_module[k] = checkpoint["model"]['module.' + k]
v.load_state_dict(checkpoint_no_module)
# replace the classification head for the 10-class target task
v.head = nn.Linear(in_features=v.head.in_features, out_features=10, bias=True)

v.to(device)
train_model(...)  # this function works with DeiT
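By the way, the 78% ImageNet figure came from a top-1 check done before swapping the head, roughly along these lines (a sketch; `val_loader` and `device` come from my own script):

import torch

v.eval()
correct = total = 0
with torch.no_grad():
    for images, targets in val_loader:
        logits = v(images.to(device))
        correct += (logits.argmax(dim=1).cpu() == targets).sum().item()
        total += targets.size(0)
print(f'top-1 accuracy: {100.0 * correct / total:.2f}%')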

Do you think that the problem could come from the modification of the head? I also tried to build CaiT with 10 classes directly, cait_XXS_224(num_classes = 10, pretrained = False), and loaded the weights except for the head, but I had the same problem as before.
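For that second attempt, the loading looked roughly like this (just a sketch to illustrate what I mean, not the exact script; strict=False leaves the freshly initialised 10-class head untouched):

import torch

v = cait_XXS_224(num_classes=10, pretrained=False)
checkpoint = torch.load('logs/ImageNet/XXS24_24.pth', map_location='cpu')
# keep every pretrained weight except the classifier head
filtered = {k.replace('module.', '', 1): w
            for k, w in checkpoint["model"].items()
            if not k.replace('module.', '', 1).startswith('head.')}
msg = v.load_state_dict(filtered, strict=False)
print(msg)  # only the head keys should be reported as missing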

Thank you for your help,

Basile

TouvronHugo commented 3 years ago

Hi @BasileR,

I don't think that the problem comes from the head. Could you give me your training hparams?

Best,
Hugo

BasileR commented 3 years ago

@TouvronHugo ,

lr = 0.01 and 0.001
momentum = 0.9
weight decay = 5e-4
optimizer: SGD
scheduler: CosineAnnealingLR
batch_size = 256
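In code that corresponds roughly to the following (a sketch; `v` is the model from my earlier snippet and `num_epochs` is whatever I train for):

import torch

optimizer = torch.optim.SGD(v.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)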

Basile

TouvronHugo commented 3 years ago

Hi @BasileR,

Thank you for your answer. Do you use stochastic depth? Did you try to do the fine-tuning with AdamW?

Best,
Hugo

BasileR commented 3 years ago

Hello @TouvronHugo,

I tried with and without stochastic depth, and with AdamW, but it did not work in any of these cases. I have tested DeiT, T2T-ViT and DeepViT (from other repos) with pretrained weights, and all of them worked with the same scripts. The only thing that changes for CaiT is the 'module.' prefix handling when loading the checkpoint.

Thank you,

Basile

TouvronHugo commented 3 years ago

Hi @BasileR,

Thanks for your answer. I don't really see where the problem comes from. Do you have the same problem with the CaiT models from this repo: timm?

Best,
Hugo
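P.S. Something like the following should be enough to try the timm version (a sketch; it assumes a timm release recent enough to include the CaiT models, which may be newer than the 0.4.5 mentioned above):

import timm

# timm downloads the pretrained weights and re-initialises the head for num_classes
m = timm.create_model('cait_xxs24_224', pretrained=True, num_classes=10)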

TouvronHugo commented 3 years ago

As there is no more activity, I am closing the issue, but feel free to re-open it if needed.

KoalaSheep commented 3 years ago

Sorry for bothering you. I just ran into a similar problem, and a smaller learning rate solved it. To be specific, I tried to fine-tune cait_S24_224 on the JSRT dataset, a small dataset of 247 chest x-ray images. With AdamW and learning rates of 1e-2, 1e-3 and 1e-4, the model did not converge (AUC around 0.6 or even worse). Then I noticed this issue: https://github.com/lucidrains/vit-pytorch/issues/34 After setting the learning rate to 3e-5, cait_S24_224 finally converged with an AUC of 0.88 (overfitted).
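For reference, the part that mattered was just the optimizer setup, roughly like this (a sketch; `model`, `train_loader` and `device` stand for my own script's variables, and the loss is an assumption, not necessarily what you would use for JSRT):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()                           # assumed loss for illustration
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)  # the learning rate that finally converged

model.train()
for images, targets in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(images.to(device)), targets.to(device))
    loss.backward()
    optimizer.step()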