shkarupa-alex opened 3 years ago
I actually found that none of the provided BiT-S-* models are compatible, i.e. they all raise an incompatible-shape exception when e.g. trying to fine-tune on the CIFAR-10 dataset (while the BiT-M-* models work flawlessly!). Is this a known issue? Is there anything that needs to be adjusted?
@shkarupa-alex when I use the suggested command to fine-tune on CIFAR10:

```
python3 -m bit_tf2.train --name cifar10_`date +%F_%H%M%S` --model BiT-S-R50x1 --logdir /tmp/bit_logs --dataset cifar10
```

then, different from your findings, I receive `ValueError: Shapes (2048, 21843) and (2048, 1000) are incompatible`.
My bad for not seeing the obvious: the difference in the two shapes is due to the BiT-M base models expecting 21843 outputs (the number of classes in ImageNet-21k), whereas the BiT-S models expect 1000 outputs (the number of classes in ILSVRC-2012). In the current implementation, `num_outputs` is hardcoded to `21843`. It needs to be selected based on the pre-trained model in use, e.g. by adding
```python
NUM_OUTPUTS = {
    k: 1000 if "-S-" in k else 21843
    for k in KNOWN_MODELS
}
```
in https://github.com/google-research/big_transfer/blob/master/bit_tf2/train.py.
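To illustrate the selection logic on its own, here is a minimal, self-contained sketch; the `KNOWN_MODELS` stand-in below is illustrative (in the repo it lives in `bit_tf2/models.py` and maps names to model builders), and the commented construction call is hypothetical:

```python
# Stand-in for bit_tf2.models.KNOWN_MODELS; the real dict maps model
# names to architecture builders.
KNOWN_MODELS = {
    "BiT-S-R50x1": None,
    "BiT-S-R101x3": None,
    "BiT-M-R50x1": None,
    "BiT-M-R152x4": None,
}

# Head size of the pre-trained checkpoint: BiT-S was trained on
# ILSVRC-2012 (1000 classes), BiT-M on ImageNet-21k (21843 classes).
NUM_OUTPUTS = {k: 1000 if "-S-" in k else 21843 for k in KNOWN_MODELS}

# In train.py the model would then be built with the matching head,
# e.g. (hypothetical call, mirroring the repo's model construction):
# model = models.KNOWN_MODELS[args.model](num_outputs=NUM_OUTPUTS[args.model])

print(NUM_OUTPUTS["BiT-S-R50x1"])   # 1000
print(NUM_OUTPUTS["BiT-M-R152x4"])  # 21843
```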
I've tried to load all of the provided TF2 model weights and found that two of them could not be loaded:

```
ValueError: Cannot assign to variable standardized_conv2d/kernel:0 due to variable shape (7, 7, 3, 192) and value shape (7, 7, 3, 64) are incompatible
ValueError: Cannot assign to variable standardized_conv2d/kernel:0 due to variable shape (7, 7, 3, 64) and value shape (7, 7, 3, 256) are incompatible
```
All other weights loaded without problems. Sample code to reproduce: https://colab.research.google.com/drive/1s2QtVgrj2HrDs64xGMi_GsOaFR3i95v0?usp=sharing
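For what it's worth, the failing shapes look like a width-multiplier mix-up: the first convolution of a BiT ResNet has 64 × width filters, so a value shape of `(7, 7, 3, 64)` belongs to an x1 checkpoint landing in an x3 model (192 filters), and `(7, 7, 3, 256)` is an x4 checkpoint landing in an x1 model. A small sketch of the expected first-conv filter counts (the helper name is mine, not from the repo):

```python
def first_conv_filters(model_name: str) -> int:
    """Expected filter count of the 7x7 root conv for a BiT model name.

    The trailing "xN" in names like "BiT-S-R50x3" is the channel-width
    multiplier; the root conv has 64 * N filters.
    """
    width = int(model_name.rsplit("x", 1)[1])
    return 64 * width

print(first_conv_filters("BiT-S-R50x1"))   # 64
print(first_conv_filters("BiT-M-R101x3"))  # 192
print(first_conv_filters("BiT-M-R152x4"))  # 256
```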