google-research / simclr

SimCLRv2 - Big Self-Supervised Models are Strong Semi-Supervised Learners
https://arxiv.org/abs/2006.10029
Apache License 2.0

Shape doesn't match when performing linear eval #42

Closed: mengye-ren closed this issue 4 years ago

mengye-ren commented 4 years ago

Hi, thank you for the code release!

I encountered the following error when running linear evaluation on CIFAR-10.

Pretraining:

python run.py --train_mode=pretrain \
  --train_batch_size=512 --train_epochs=1000 \
  --learning_rate=1.0 --weight_decay=1e-4 --temperature=0.5 \
  --dataset=cifar10 --image_size=32 --eval_split=test --resnet_depth=18 \
  --use_blur=False --color_jitter_strength=0.5 \
  --model_dir=/mnt/research/results/simclr/simclr_test --use_tpu=False

Linear eval:

python run.py --mode=train_then_eval --train_mode=finetune \
  --fine_tune_after_block=4 --zero_init_logits_layer=True \
  --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
  --global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=0.0 \
  --train_epochs=100 --train_batch_size=512 --warmup_epochs=0 \
  --dataset=cifar10 --image_size=32 --eval_split=test --resnet_depth=18 \
  --checkpoint=/mnt/research/results/simclr/simclr_test --model_dir=/mnt/research/results/simclr/simclr_test_ft --use_tpu=False

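The --variable_schema regex above decides which pretrained variables are restored from --checkpoint. A minimal sketch of its effect, assuming the schema is applied with re.match against variable names (this is how TF1's tf.compat.v1.global_variables(scope) filters; that SimCLR restores through this path is an assumption here). The variable names are hypothetical examples, not read from a real checkpoint.

```python
import re

# Schema passed via --variable_schema in the linear-eval command.
# The negative lookahead rejects names starting with "global_step" or "head",
# and names with any path component starting with "Momentum" (optimizer slots).
VARIABLE_SCHEMA = r"(?!global_step|(?:.*/|^)Momentum|head)"


def restored_variables(names, schema=VARIABLE_SCHEMA):
    """Keep the names that re.match accepts (assumed filtering rule)."""
    return [n for n in names if re.match(schema, n)]


# Hypothetical variable names, for illustration only.
names = [
    "global_step",
    "base_model/conv2d/kernel",
    "base_model/conv2d/kernel/Momentum",
    "head_contrastive/nl_0/dense/kernel",
    "head_supervised/linear_layer/dense/kernel",
]
print(restored_variables(names))  # ['base_model/conv2d/kernel']
```

Under this assumption only backbone weights are restored; both heads, including head_supervised, start fresh during linear evaluation.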
I0625 13:45:52.051569 140622183225152 evaluation.py:276] Finished evaluation at 2020-06-25-13:45:52
INFO:tensorflow:Saving dict for global step 9766: contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 9766, label_top_1_accuracy = 0.8248, label_top_5_accuracy = 0.9829, loss = 0.5490037, regularization_loss = 0.0
I0625 13:45:52.051712 140622183225152 estimator.py:2053] Saving dict for global step 9766: contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 9766, label_top_1_accuracy = 0.8248, label_top_5_accuracy = 0.9829, loss = 0.5490037, regularization_loss = 0.0
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 9766: /mnt/research/results/simclr/simclr_test_ft/model.ckpt-9766
I0625 13:45:52.182560 140622183225152 estimator.py:2113] Saving 'checkpoint_path' summary for global step 9766: /mnt/research/results/simclr/simclr_test_ft/model.ckpt-9766
INFO:tensorflow:evaluation_loop marked as finished
I0625 13:45:52.182964 140622183225152 error_handling.py:108] evaluation_loop marked as finished
WARNING:tensorflow:From /home/mren/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow_hub/saved_model_lib.py:110: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
W0625 13:45:52.510126 140622183225152 deprecation.py:323] From /home/mren/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow_hub/saved_model_lib.py:110: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
Traceback (most recent call last):
  File "run.py", line 435, in <module>
    app.run(main)
  File "/home/mren/miniconda3/envs/tf/lib/python3.6/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/mren/miniconda3/envs/tf/lib/python3.6/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "run.py", line 430, in main
    num_classes=num_classes)
  File "run.py", line 343, in perform_evaluation
    checkpoint_path=checkpoint_path)
  File "run.py", line 293, in build_hub_module
    name_transform_fn=None)
  File "/home/mren/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow_hub/module_spec.py", line 80, in export
    export_module_spec(self, path, checkpoint_path, name_transform_fn)
  File "/home/mren/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow_hub/module.py", line 74, in export_module_spec
    tf_v1.train.init_from_checkpoint(checkpoint_path, assign_map)
  File "/home/mren/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow_core/python/training/checkpoint_utils.py", line 291, in init_from_checkpoint
    init_from_checkpoint_fn)
  File "/home/mren/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 1949, in merge_call
    return self._merge_call(merge_fn, args, kwargs)
  File "/home/mren/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 1956, in _merge_call
    return merge_fn(self._strategy, *args, **kwargs)
  File "/home/mren/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow_core/python/training/checkpoint_utils.py", line 286, in <lambda>
    ckpt_dir_or_file, assignment_map)
  File "/home/mren/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow_core/python/training/checkpoint_utils.py", line 329, in _init_from_checkpoint
    tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
ValueError: Shape of variable module/head_supervised/linear_layer/dense/kernel:0 ((512, 10)) doesn't match with shape of tensor head_supervised/linear_layer/dense/kernel ([128, 10]) from checkpoint reader.
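The failing check compares each graph variable's shape with the shape stored in the checkpoint. A minimal pure-Python sketch of that comparison, where find_shape_mismatches is a hypothetical helper and the shapes are copied from the error above. It assumes graph names have already been mapped to checkpoint names (the error shows a module/ prefix and :0 suffix on the graph side). In a real run, tf.train.list_variables(checkpoint_path) returns the checkpoint's (name, shape) pairs.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    graph_shapes: Dict[str, Shape],
    ckpt_shapes: Dict[str, Shape],
) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, graph_shape, ckpt_shape) for every variable present in
    both mappings whose shapes differ (hypothetical helper)."""
    mismatches = []
    for name, shape in sorted(graph_shapes.items()):
        if name in ckpt_shapes and tuple(ckpt_shapes[name]) != tuple(shape):
            mismatches.append((name, tuple(shape), tuple(ckpt_shapes[name])))
    return mismatches


# Shapes taken from the ValueError above.
graph = {"head_supervised/linear_layer/dense/kernel": (512, 10)}
ckpt = {"head_supervised/linear_layer/dense/kernel": (128, 10)}
for name, g, c in find_shape_mismatches(graph, ckpt):
    print(f"{name}: graph {g} vs checkpoint {c}")
```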
chentingpc commented 4 years ago

This seems to be the same issue as #39. It is caused by the recently introduced ft_proj_selector flag; you should be able to fix it by adding --ft_proj_selector=0. Please let me know whether that fixes it.
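The two widths in the error fit this explanation: a ResNet-18 at width 1 produces 512-dimensional features, and 128 matches the assumed default projection output size (proj_out_dim). The sketch below shows how the selector would change the supervised head's kernel shape. It assumes a SimCLR-style nonlinear projection head with num_proj_layers=3, where index 0 of the list is the head's input (the backbone features), middle layers keep the backbone width, and the fine-tuning head reads entry ft_proj_selector. The helper names are hypothetical; the repository's projection-head code is authoritative.

```python
from typing import List, Tuple


def proj_head_widths(backbone_dim: int, num_proj_layers: int = 3,
                     proj_out_dim: int = 128) -> List[int]:
    """Widths of the tensors in an assumed SimCLR-style projection head.

    Index 0 is the head's input; middle layers keep backbone_dim;
    the final layer projects to proj_out_dim.
    """
    widths = [backbone_dim]
    for j in range(num_proj_layers):
        is_last = j == num_proj_layers - 1
        widths.append(proj_out_dim if is_last else backbone_dim)
    return widths


def linear_head_kernel_shape(backbone_dim: int, num_classes: int,
                             ft_proj_selector: int, num_proj_layers: int = 3,
                             proj_out_dim: int = 128) -> Tuple[int, int]:
    """Kernel shape of the supervised linear head for a given selector."""
    widths = proj_head_widths(backbone_dim, num_proj_layers, proj_out_dim)
    return (widths[ft_proj_selector], num_classes)


# ResNet-18 (width 1) gives 512-d features; CIFAR-10 has 10 classes.
print(linear_head_kernel_shape(512, 10, ft_proj_selector=0))   # (512, 10)
print(linear_head_kernel_shape(512, 10, ft_proj_selector=-1))  # (128, 10)
```

Under these assumptions, selector 0 reads the backbone features directly, which matches the (512, 10) kernel expected by the exported module.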

mengye-ren commented 4 years ago

Thanks for the quick response. Adding this flag fixes the problem.