google-research / big_transfer

Official repository for the "Big Transfer (BiT): General Visual Representation Learning" paper.
https://arxiv.org/abs/1912.11370
Apache License 2.0

TF2 weights shape wrong for 2 architectures #47

Open shkarupa-alex opened 3 years ago

shkarupa-alex commented 3 years ago

I've tried to load all provided TF2 model weights and found that 2 of them could not be loaded:

All other weights loaded without problems. Sample code to reproduce: https://colab.research.google.com/drive/1s2QtVgrj2HrDs64xGMi_GsOaFR3i95v0?usp=sharing
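The colab itself is not reproduced here, but the survey it describes (try every released checkpoint and note which ones fail) boils down to a loop like the sketch below. `find_unloadable` and `load_fn` are hypothetical names. In practice `load_fn` would build the matching TF2 model and call its weight-loading method. A stand-in loader keeps the sketch self-contained.

```python
from typing import Callable, Dict, Iterable


def find_unloadable(
    model_names: Iterable[str], load_fn: Callable[[str], object]
) -> Dict[str, str]:
    """Call load_fn on each model name and collect the ones that fail.

    Returns a mapping from model name to "ExceptionType: message".
    """
    failures = {}
    for name in model_names:
        try:
            load_fn(name)
        except Exception as exc:  # broad on purpose: this is a survey, not a fix
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures
```

Collecting messages instead of stopping at the first exception shows at a glance whether all failures share one cause, such as the same shape mismatch.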

chrstn-hntschl commented 3 years ago

I actually found that none of the provided BiT-S models are compatible, i.e. they all raise an incompatible-shape exception when, e.g., fine-tuning on the CIFAR10 dataset (while the BiT-M models work flawlessly!). Is this a known issue? Does anything need to be adjusted?

@shkarupa-alex When I use the suggested command to fine-tune on CIFAR10, python3 -m bit_tf2.train --name cifar10_`date +%F_%H%M%S` --model BiT-S-R50x1 --logdir /tmp/bit_logs --dataset cifar10, I get a different result from yours: ValueError: Shapes (2048, 21843) and (2048, 1000) are incompatible.
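Both shapes in this error describe the kernel of the final dense (logits) layer, laid out as (pre-logits features, classes). The small sketch below, whose helper name is hypothetical, assumes the standard bottleneck-ResNet layout. In that layout the feature dimension is 2048 times the width multiplier regardless of depth, so only the class count differs between the two shapes.

```python
import re

# Pattern for released BiT model names such as "BiT-S-R50x1" or "BiT-M-R152x4".
_MODEL_NAME = re.compile(r"BiT-[SM]-R(?P<depth>\d+)x(?P<width>\d+)")


def head_kernel_shape(model_name: str, num_outputs: int) -> tuple:
    """Return the (features, classes) shape of the final dense kernel.

    Assumes the bottleneck ResNet layout, where the pre-logits feature
    dimension is 2048 times the width multiplier for every depth.
    """
    match = _MODEL_NAME.fullmatch(model_name)
    if match is None:
        raise ValueError(f"Unrecognized BiT model name: {model_name!r}")
    width = int(match.group("width"))
    return (2048 * width, num_outputs)


# Model side: a head built with 21843 outputs.
print(head_kernel_shape("BiT-S-R50x1", 21843))  # (2048, 21843)
# Checkpoint side: the shape reported for the stored BiT-S-R50x1 weights.
print(head_kernel_shape("BiT-S-R50x1", 1000))   # (2048, 1000)
```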

chrstn-hntschl commented 3 years ago

My bad for not seeing the obvious: the two shapes differ because the BiT-M base models expect 21843 outputs (the number of classes in ImageNet-21k), whereas the BiT-S models expect 1000 outputs (the number of classes in ILSVRC-2012). The current implementation hardcodes num_outputs=21843. This value needs to be selected based on the pre-trained model being used, e.g. by adding

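# BiT-S checkpoints were pre-trained on ILSVRC-2012 (1000 classes),
# BiT-M checkpoints on ImageNet-21k (21843 classes).
NUM_OUTPUTS = {
    k: 1000 if "-S-" in k else 21843
    for k in KNOWN_MODELS
}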

to https://github.com/google-research/big_transfer/blob/master/bit_tf2/train.py and using the entry for the selected model instead of the hardcoded num_outputs=21843.
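A stricter variant of the same idea is sketched below, under the assumption that model names always take the form "<variant>-<architecture>" (e.g. "BiT-S-R50x1"). `PRETRAIN_CLASSES` and `num_outputs_for` are hypothetical names, not part of the repository. The lookup keys on the variant prefix instead of testing for "-S-", so any name that is neither BiT-S nor BiT-M raises an error instead of silently falling back to 21843.

```python
# Hypothetical lookup: classes in each variant's upstream pre-training dataset.
PRETRAIN_CLASSES = {
    "BiT-S": 1000,   # ILSVRC-2012
    "BiT-M": 21843,  # ImageNet-21k
}


def num_outputs_for(model_name: str) -> int:
    """Return the head size matching a BiT checkpoint such as "BiT-S-R50x1"."""
    variant = model_name.rsplit("-", 1)[0]  # "BiT-S-R50x1" -> "BiT-S"
    try:
        return PRETRAIN_CLASSES[variant]
    except KeyError:
        # Fail loudly instead of silently defaulting to 21843.
        raise ValueError(f"Unknown BiT variant in model name: {model_name!r}") from None


print(num_outputs_for("BiT-S-R50x1"))   # 1000
print(num_outputs_for("BiT-M-R152x4"))  # 21843
```

The returned value would then be passed as num_outputs when the model is constructed, in place of the hardcoded constant.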