SHI-Labs / Compact-Transformers

Escaping the Big Data Paradigm with Compact Transformers, 2021 (Train your Vision Transformers in 30 mins on CIFAR-10 with a single GPU!)
https://arxiv.org/abs/2104.05704
Apache License 2.0

Config for training Flowers SOTA #48

Closed · wsgharvey closed this issue 2 years ago

wsgharvey commented 2 years ago

Hi,

I'm trying to figure out how to train your model to achieve the SOTA accuracy you report on Flowers102. It seems that using finetuned/cct_14-7x2_flowers102.yml will download the model with the 99.76% test accuracy you report, but I can't find any config files that actually train this model from scratch (or from, e.g., an ImageNet checkpoint, if that's what you use). Do you mind pointing me to any config files for this that I might have missed, or else to a description of the training procedure for your SOTA Flowers102 model, so that I can try to reproduce it?

Thanks for your help, Will

alihassanijr commented 2 years ago

Hello, and thank you for your interest. That file will actually download the ImageNet checkpoint, not the Flowers checkpoint: the model specified in it is cct_14_7x2_384, which points to the ImageNet checkpoint, whereas cct_14_7x2_384_fl will download the Flowers checkpoint. Could you share more details, perhaps the command you used?
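
For illustration, both names are registered with timm and resolve to different pretrained weights. A minimal sketch (the exact import path for registering the models is an assumption here):

import timm
import src.cct  # noqa: F401 -- assumed import; registers the CCT variants with timm

# Same architecture, different checkpoints:
imagenet_model = timm.create_model('cct_14_7x2_384', pretrained=True)    # ImageNet weights, 1000-way head
flowers_model = timm.create_model('cct_14_7x2_384_fl', pretrained=True)  # Flowers102 weights, 102-way head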

wsgharvey commented 2 years ago

Thanks for the quick response! I see the issue: I'd been overriding --model, running python train.py path/to/flower_data --config=configs/finetuned/cct_14-7x2_flowers102.yml --model=cct_14_7x2_384_fl.

However, when I leave the model as cct_14_7x2_384, I get a different error; the command python train.py path/to/flower_data --config=configs/finetuned/cct_14-7x2_flowers102.yml fails with:

(transformers) wsgh@alexandria:/ubc/cs/research/plai-scratch/wsgh/Compact-Transformers$ python train.py flower_data --config=configs/finetuned/cct_14-7x2_flowers102.yml
Training with a single process on 1 GPUs.
Traceback (most recent call last):
  File "train.py", line 806, in <module>
    main()
  File "train.py", line 366, in main
    checkpoint_path=args.initial_checkpoint)
  File "/ubc/cs/research/plai-scratch/wsgh/envs/transformers/lib/python3.6/site-packages/timm/models/factory.py", line 81, in create_model
    model = create_fn(pretrained=pretrained, **kwargs)
  File "/ubc/cs/research/plai-scratch/wsgh/Compact-Transformers/src/cct.py", line 335, in cct_14_7x2_384
    *args, **kwargs)
  File "/ubc/cs/research/plai-scratch/wsgh/Compact-Transformers/src/cct.py", line 137, in cct_14
    *args, **kwargs)
  File "/ubc/cs/research/plai-scratch/wsgh/Compact-Transformers/src/cct.py", line 109, in _cct
    model.load_state_dict(state_dict)
  File "/ubc/cs/research/plai-scratch/wsgh/envs/transformers/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CCT:
        size mismatch for classifier.fc.weight: copying a param with shape torch.Size([1000, 384]) from checkpoint, the shape in current model is torch.Size([102, 384]).
        size mismatch for classifier.fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([102]).

alihassanijr commented 2 years ago

Thank you for sharing. This is certainly a bug on our side: the ImageNet checkpoint has a 1000-way classifier while Flowers102 has 102 classes, and the mismatched head should only have raised a warning, not an error. To be clear, --model should not be overridden; the config files contain all the settings needed to reproduce the results.

I will try and fix this today.
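
The fix will be along these lines (a sketch of the idea, not the exact patch; load_pretrained_filtered is a hypothetical helper): drop checkpoint tensors whose shapes don't match the current model before loading, and warn about them instead of failing.

import warnings

def load_pretrained_filtered(model, state_dict):
    # Hypothetical helper sketching the intended fix: skip checkpoint
    # parameters whose shapes disagree with the current model (e.g. the
    # 1000-way ImageNet classifier vs. a 102-way Flowers102 head),
    # warn about each skipped key, and load the rest non-strictly.
    model_sd = model.state_dict()
    filtered = {}
    for k, v in state_dict.items():
        if k in model_sd and model_sd[k].shape != v.shape:
            warnings.warn(f"Skipping {k}: checkpoint shape {tuple(v.shape)} "
                          f"does not match model shape {tuple(model_sd[k].shape)}")
            continue
        filtered[k] = v
    model.load_state_dict(filtered, strict=False)
    return model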

wsgharvey commented 2 years ago

Sounds good, thanks for the help!

alihassanijr commented 2 years ago

Please try it now, let me know if it works.

wsgharvey commented 2 years ago

Yes, it's working for me now. Cheers!