BasileR closed this issue 3 years ago.
I have a similar problem and don't know how to resolve it.
Hi @BasileR and @reallm, thanks for your question. This is a strange problem, as it works well for DeiT. Did you check that you get the expected performance on ImageNet, to rule out a problem when loading the weights? Best, Hugo
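One way to run that check is a plain top-1 evaluation loop over the ImageNet validation set. The sketch below assumes `loader` yields `(images, targets)` batches (for example a `torch.utils.data.DataLoader`); `top1_accuracy` is a hypothetical helper name, not part of this repo.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def top1_accuracy(model: nn.Module, loader, device: torch.device) -> float:
    """Return the top-1 accuracy (in %) of `model` over every batch in `loader`."""
    model.eval()  # disable dropout / stochastic depth during evaluation
    correct, total = 0, 0
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        preds = model(images).argmax(dim=1)  # predicted class per sample
        correct += (preds == targets).sum().item()
        total += targets.numel()
    return 100.0 * correct / total
```

A result far below the published accuracy for the checkpoint would point to a loading problem rather than a fine-tuning one.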
Hi @TouvronHugo,
Thank you for your answer. I tested cait_XXS_224 on ImageNet and got 78% accuracy, so I don't think there is a problem with loading the weights. Here is my code:
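import torch
import torch.nn as nn
# cait_XXS_224 is imported from the CaiT model definitions (not shown)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

v = cait_XXS_224(pretrained=False)
checkpoint = torch.load('logs/ImageNet/XXS24_24.pth', map_location="cpu")
# Saved keys carry a 'module.' prefix (typically from a DataParallel/DDP wrapper), so add it for lookup
checkpoint_no_module = {}
for k in v.state_dict().keys():
    checkpoint_no_module[k] = checkpoint["model"]['module.' + k]
v.load_state_dict(checkpoint_no_module)

# Replace the 1000-class ImageNet head with a 10-class head
v.head = nn.Linear(in_features=v.head.in_features, out_features=10, bias=True)
v.to(device)
train_model(...)  # this function works with DeiT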
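The `'module.'` prefix looked up in the loop above is what `torch.nn.DataParallel` and `DistributedDataParallel` typically add when a wrapped model's `state_dict()` is saved. A slightly more general sketch (the helper name `strip_module_prefix` is hypothetical) removes the prefix only where it is present, so it also works for checkpoints saved from an unwrapped model:

```python
from collections import OrderedDict
from typing import Mapping


def strip_module_prefix(state_dict: Mapping, prefix: str = "module.") -> OrderedDict:
    """Return a copy of `state_dict` with a leading `prefix` removed from each key.

    Keys that do not start with `prefix` are kept unchanged, so the helper is a
    no-op on checkpoints saved from a model that was never wrapped.
    """
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
```

With it, the loading step could read `v.load_state_dict(strip_module_prefix(checkpoint["model"]))`. Since `load_state_dict` is strict by default, any remaining key mismatch still raises an error rather than being skipped silently.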
Do you think the problem could come from modifying the head? I also tried to train CaiT with 10 classes directly:
cait_XXS_224(num_classes=10, pretrained=False)
and loaded all the weights except for the head, but I had the same problem as before.
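A minimal sketch of loading everything except the head, assuming the classifier parameters are named `head.weight` and `head.bias` (as the `v.head` attribute suggests) and that any `module.` prefix has already been removed from the checkpoint keys; `load_all_but_head` is a hypothetical helper name. Because `strict=False` silently skips keys that do not match, the sketch inspects the returned missing/unexpected keys, so a checkpoint that was not actually loaded raises an error instead of leaving the backbone at its random initialisation.

```python
from typing import Mapping

import torch
import torch.nn as nn


def load_all_but_head(model: nn.Module, pretrained_state: Mapping[str, torch.Tensor],
                      head_prefix: str = "head."):
    """Load every pretrained tensor except the classifier head into `model`."""
    # Drop the head: its shape (1000 ImageNet classes) differs from the new one.
    filtered = {k: v for k, v in pretrained_state.items() if not k.startswith(head_prefix)}
    result = model.load_state_dict(filtered, strict=False)
    # strict=False raises nothing on name mismatches, so verify explicitly that
    # only the head was left unloaded and that no pretrained key was ignored.
    bad_missing = [k for k in result.missing_keys if not k.startswith(head_prefix)]
    if bad_missing or result.unexpected_keys:
        raise RuntimeError(f"missing: {bad_missing}, unexpected: {result.unexpected_keys}")
    return result
```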
Thank you for your help,
Basile
Hi @BasileR, I don't think the problem comes from the head. Could you give me your training hyperparameters? Best, Hugo
@TouvronHugo,
lr: 0.01 and 0.001
momentum: 0.9
weight decay: 5e-4
optimizer: SGD
scheduler: CosineAnnealingLR
batch_size: 256
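For reference, those settings correspond roughly to the following PyTorch setup. This is a sketch: the helper name `build_sgd_cosine` is hypothetical, "weight decay" is assumed to mean SGD's `weight_decay` argument, and `epochs` is a placeholder because the number of epochs was not given.

```python
import torch
import torch.nn as nn


def build_sgd_cosine(model: nn.Module, lr: float = 0.01, epochs: int = 100):
    """SGD with momentum and weight decay, followed by a cosine learning-rate schedule."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    # T_max counts scheduler.step() calls; here the schedule is stepped once per epoch,
    # so the learning rate decays from `lr` to 0 over `epochs` epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    return optimizer, scheduler
```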
Basile
Hi @BasileR, thank you for your answer. Do you use stochastic depth? Did you try fine-tuning with AdamW? Best, Hugo
Hello @TouvronHugo,
I tried with and without stochastic depth, and with AdamW, but it did not work in any case. I have tested DeiT, T2T-ViT and DeepViT (from other repos) with pretrained weights, and all of them worked with the same scripts. The only thing that changes is this 'module.' key handling when loading the weights.
Thank you,
Basile
Hi @BasileR, thanks for your answer. I don't really see where the problem comes from. Do you have the same problem with the CaiT models from the timm repo? Best, Hugo
As there has been no further activity, I am closing the issue, but feel free to reopen it if needed.
Sorry to bother you. I just ran into a similar problem, and a smaller learning rate solved it. Specifically, I tried to fine-tune cait_S24_224 on the JSRT dataset, a small dataset of 247 chest X-ray images. With AdamW and the learning rate set to 1e-2, 1e-3 or 1e-4, the model did not converge, reaching an AUC of around 0.6 or even worse. Then I noticed this issue: https://github.com/lucidrains/vit-pytorch/issues/34. After setting the learning rate to 3e-5, cait_S24_224 finally converged, with an AUC of 0.88 (overfitted).
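The comparison described in that comment can be organised as a small learning-rate sweep in which every run starts from the same pretrained weights. This is a sketch only: `sweep_adamw_lr` and `train_and_eval` are hypothetical names, the latter standing in for the actual fine-tuning and AUC evaluation, and AdamW is used with PyTorch's default weight decay since none was reported.

```python
import copy
from typing import Callable, Dict, Iterable

import torch
import torch.nn as nn


def sweep_adamw_lr(
    base_model: nn.Module,
    lrs: Iterable[float],
    train_and_eval: Callable[[nn.Module, torch.optim.Optimizer], float],
) -> Dict[float, float]:
    """Fine-tune a fresh copy of `base_model` with AdamW at each lr; return {lr: score}."""
    scores = {}
    for lr in lrs:
        model = copy.deepcopy(base_model)  # each run starts from the same pretrained weights
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        scores[lr] = train_and_eval(model, optimizer)
    return scores
```

For example, `sweep_adamw_lr(model, [1e-2, 1e-3, 1e-4, 3e-5], train_and_eval)` covers the four settings mentioned above.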
Hello,
Thanks a lot for this great repo. I'm currently doing transfer learning with CaiT XXS24 224, but I have a problem with the pretrained weights: when I train CaiT on the new task, the accuracy starts at 10% (chance level) and never increases. I trained DeiT-Small on the same task with transfer learning, and this time it worked well (with the same training functions). Do you have any idea what could be the problem?
Here is the code I use to load the weights (it is actually the one you provide):
v = cait_XXS_224(pretrained=False)
checkpoint = torch.load('logs/ImageNet/XXS24_24.pth')
checkpoint_no_module = {}
for k in v.state_dict().keys():
    checkpoint_no_module[k] = checkpoint["model"]['module.' + k]
v.load_state_dict(checkpoint_no_module)
I'm using torch 1.7.1 and timm 0.4.5.