Closed wsgharvey closed 2 years ago
Hello,
Thank you for your interest.
Actually this file will download the ImageNet checkpoint, not the Flowers checkpoint.
The model specified in it is cct_14_7x2_384
which points to the ImageNet checkpoint.
cct_14_7x2_384_fl
will download the Flowers checkpoint.
Could you share more details? Perhaps the command you used?
Thanks for the quick response! I see the issue - I'd been overriding the --model flag, running the command python train.py path/to/flower_data --config=configs/finetuned/cct_14-7x2_flowers102.yml --model=cct_14_7x2_384_fl.
However, when I use the cct_14_7x2_384 model, the command python train.py path/to/flower_data --config=configs/finetuned/cct_14-7x2_flowers102.yml gives me a different error:
(transformers) wsgh@alexandria:/ubc/cs/research/plai-scratch/wsgh/Compact-Transformers$ python train.py flower_data --config=configs/finetuned/cct_14-7x2_flowers102.yml
Training with a single process on 1 GPUs.
Traceback (most recent call last):
  File "train.py", line 806, in <module>
    main()
  File "train.py", line 366, in main
    checkpoint_path=args.initial_checkpoint)
  File "/ubc/cs/research/plai-scratch/wsgh/envs/transformers/lib/python3.6/site-packages/timm/models/factory.py", line 81, in create_model
    model = create_fn(pretrained=pretrained, **kwargs)
  File "/ubc/cs/research/plai-scratch/wsgh/Compact-Transformers/src/cct.py", line 335, in cct_14_7x2_384
    *args, **kwargs)
  File "/ubc/cs/research/plai-scratch/wsgh/Compact-Transformers/src/cct.py", line 137, in cct_14
    *args, **kwargs)
  File "/ubc/cs/research/plai-scratch/wsgh/Compact-Transformers/src/cct.py", line 109, in _cct
    model.load_state_dict(state_dict)
  File "/ubc/cs/research/plai-scratch/wsgh/envs/transformers/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CCT:
    size mismatch for classifier.fc.weight: copying a param with shape torch.Size([1000, 384]) from checkpoint, the shape in current model is torch.Size([102, 384]).
    size mismatch for classifier.fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([102]).
Thank you for sharing. This is certainly a bug on our side: the ImageNet checkpoint has 1000 outputs while Flowers has 102, and that mismatch should have raised only a warning, not an error. To be clear, --model should not be overridden; the config files already contain the settings needed to reproduce our results.
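For anyone hitting the same size-mismatch error before the fix lands: the usual workaround is to drop checkpoint parameters whose shapes don't match the current model before loading. A minimal sketch below - load_filtered_state_dict is a hypothetical helper, not part of this repo, and it mirrors the warning-instead-of-error behavior described above:

```python
import torch
import torch.nn as nn

def load_filtered_state_dict(model: nn.Module, state_dict: dict) -> None:
    """Load a checkpoint, skipping any parameter whose shape doesn't match
    the current model (e.g. a 1000-way ImageNet head vs. a 102-way
    Flowers head), and warn instead of raising.

    Hypothetical helper for illustration only.
    """
    model_sd = model.state_dict()
    filtered = {}
    for key, value in state_dict.items():
        if key in model_sd and model_sd[key].shape == value.shape:
            filtered[key] = value
        else:
            expected = tuple(model_sd[key].shape) if key in model_sd else "missing"
            print(f"Warning: skipping '{key}' "
                  f"(checkpoint shape {tuple(value.shape)}, model expects {expected})")
    # strict=False tolerates the keys we filtered out above
    model.load_state_dict(filtered, strict=False)
```

With this, the shared backbone weights are copied while the mismatched classifier.fc.weight / classifier.fc.bias entries are skipped with a printed warning, leaving the freshly initialized 102-way head in place for fine-tuning.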
I will try and fix this today.
Sounds good, thanks for the help!
Please try it now, let me know if it works.
Yes, it's working for me now. Cheers!
Hi,
I'm trying to figure out how to train your model to achieve the SOTA accuracy you report on Flowers102. It seems like using
finetuned/cct_14-7x2_flowers102.yml
will download the model with the 99.76% test accuracy you report, but I can't find any config files that actually train this model from scratch (or from e.g. an ImageNet checkpoint, if that's what you use). Would you mind pointing me to any config files for this that I might have missed, or else to a description of the training procedure for your SOTA Flowers102 model, so that I can try to reproduce it?
Thanks for your help, Will