tensorflow / models

Models and examples built with TensorFlow

eager_few_shot_od_training_tf2_colab classification head restoring error #9382

Open Jaredeco opened 3 years ago

Jaredeco commented 3 years ago

Hello, I am trying to train the object detection model as in the eager_few_shot_od_training_tf2_colab notebook. The only change I make is to restore the classification head of the model, because I later want to train it on a dataset with three classes; everything else is the same as in the repo on GitHub. However, when I uncomment the line that restores the classification head (see the snippet below, where the notebook comments suggest doing exactly this), I get an error.

# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    _prediction_heads=detection_model._box_predictor._prediction_heads, #   <-----------------I uncommented this line
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()

# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')

This is the error that I get:

ValueError                                Traceback (most recent call last)
    <ipython-input-7-96e77f9f8468> in <module>
         24 
         25 image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
    ---> 26 prediction_dict = detection_model.predict(image, shapes)
         27 _ = detection_model.postprocess(prediction_dict, shapes)
         28 print('Weights restored!')

C:\Python\lib\site-packages\object_detection\meta_architectures\ssd_meta_arch.py in predict(self, preprocessed_inputs, true_image_shapes)
    589     self._anchors = box_list_ops.concatenate(boxlist_list)
    590     if self._box_predictor.is_keras_model:
--> 591       predictor_results_dict = self._box_predictor(feature_maps)
    592     else:
    593       with slim.arg_scope([slim.batch_norm],

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):
--> 985           outputs = call_fn(inputs, *args, **kwargs)
    986 
    987         if self._activity_regularizer:

C:\Python\lib\site-packages\object_detection\core\box_predictor.py in call(self, image_features, **kwargs)
    200           feature map in the input `image_features` list.
    201     """
--> 202     return self._predict(image_features, **kwargs)
    203 
    204   @abstractmethod

C:\Python\lib\site-packages\object_detection\predictors\convolutional_keras_box_predictor.py in _predict(self, image_features, **kwargs)
    482               self._base_tower_layers_for_heads[head_name][index],
    483               image_feature)
--> 484         prediction = head_obj(head_tower_feature)
    485         predictions[head_name].append(prediction)
    486     return predictions

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):
--> 985           outputs = call_fn(inputs, *args, **kwargs)
    986 
    987         if self._activity_regularizer:

C:\Python\lib\site-packages\object_detection\predictors\heads\head.py in call(self, features)
     67   def call(self, features):
     68     """The Keras model call will delegate to the `_predict` method."""
---> 69     return self._predict(features)
     70 
     71   @abstractmethod

C:\Python\lib\site-packages\object_detection\predictors\heads\keras_class_head.py in _predict(self, features)
    339     for layer in self._class_predictor_layers:
    340       class_predictions_with_background = layer(
--> 341           class_predictions_with_background)
    342     batch_size = features.get_shape().as_list()[0]
    343     if batch_size is None:

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    980       with ops.name_scope_v2(name_scope):
    981         if not self.built:
--> 982           self._maybe_build(inputs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in _maybe_build(self, inputs)
   2641         # operations.
   2642         with tf_utils.maybe_init_scope(self):
-> 2643           self.build(input_shapes)  # pylint:disable=not-callable
   2644       # We must set also ensure that the layer is marked as built, and the build
   2645       # shape is stored since user defined build functions may not be calling

C:\Python\lib\site-packages\tensorflow\python\keras\layers\convolutional.py in build(self, input_shape)
    202         constraint=self.kernel_constraint,
    203         trainable=True,
--> 204         dtype=self.dtype)
    205     if self.use_bias:
    206       self.bias = self.add_weight(

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in add_weight(self, name, shape, dtype, initializer, regularizer, trainable, constraint, partitioner, use_resource, synchronization, aggregation, **kwargs)
    612         synchronization=synchronization,
    613         aggregation=aggregation,
--> 614         caching_device=caching_device)
    615     if regularizer is not None:
    616       # TODO(fchollet): in the future, this should be handled at the

C:\Python\lib\site-packages\tensorflow\python\training\tracking\base.py in _add_variable_with_custom_getter(self, name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)
    729         # there is nothing to restore.
    730         checkpoint_initializer = self._preload_simple_restoration(
--> 731             name=name, shape=shape)
    732       else:
    733         checkpoint_initializer = None

C:\Python\lib\site-packages\tensorflow\python\training\tracking\base.py in _preload_simple_restoration(self, name, shape)
    796         key=lambda restore: restore.checkpoint.restore_uid)
    797     return CheckpointInitialValue(
--> 798         checkpoint_position=checkpoint_position, shape=shape)
    799 
    800   def _track_trackable(self, trackable, name, overwrite=False):

C:\Python\lib\site-packages\tensorflow\python\training\tracking\base.py in __init__(self, checkpoint_position, shape)
     73       # We need to set the static shape information on the initializer if
     74       # possible so we don't get a variable with an unknown shape.
---> 75       self.wrapped_value.set_shape(shape)
     76     self._checkpoint_position = checkpoint_position
     77 

C:\Python\lib\site-packages\tensorflow\python\framework\ops.py in set_shape(self, shape)
   1207       raise ValueError(
   1208           "Tensor's shape %s is not compatible with supplied shape %s" %
-> 1209           (self.shape, shape))
   1210 
   1211   # Methods not supported / implemented for Eager Tensors.

ValueError: Tensor's shape (3, 3, 256, 546) is not compatible with supplied shape (3, 3, 256, 24)
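The traceback is easier to follow with a simplified model of object-based restoring. The sketch below is a toy in plain Python, not TensorFlow's implementation: `toy_restore`, `SAVED`, `MODEL` and the variable paths are hypothetical, and the shapes are illustrative. Only variables reachable through attributes attached to the fake `Checkpoint` are matched against the checkpoint; everything else keeps its fresh initialization, and `expect_partial()` only suppresses warnings about checkpoint values that were never matched. In real TensorFlow the shape check is deferred until a layer first builds its variables, which is why the error surfaces in `detection_model.predict(...)` rather than in `ckpt.restore(...)`.

```python
from typing import Dict, Set, Tuple

Shape = Tuple[int, ...]


def toy_restore(saved: Dict[str, Shape], model: Dict[str, Shape],
                tracked: Set[str]) -> Dict[str, str]:
    """Toy model of object-based restore (NOT TensorFlow's real code).

    saved:   hypothetical variable path -> shape stored in the checkpoint
    model:   hypothetical variable path -> shape the new model expects
    tracked: top-level attribute names attached to the fake Checkpoint
    Returns path -> "restored" or "initialized".
    """
    status = {}
    for path, shape in model.items():
        top = path.split("/")[0]
        if top in tracked and path in saved:
            if saved[path] != shape:
                # Same wording as the ValueError in the traceback above.
                raise ValueError(
                    f"Tensor's shape {saved[path]} is not compatible "
                    f"with supplied shape {shape}")
            status[path] = "restored"
        else:
            # Not attached to the fake Checkpoint: keeps its fresh initialization.
            status[path] = "initialized"
    return status


# Illustrative shapes only: a 90-class checkpoint vs. a 3-class model.
SAVED = {
    "_box_prediction_head/kernel": (3, 3, 256, 24),
    "_prediction_heads/class/kernel": (3, 3, 256, 546),
}
MODEL = {
    "_box_prediction_head/kernel": (3, 3, 256, 24),
    "_prediction_heads/class/kernel": (3, 3, 256, 24),
}

if __name__ == "__main__":
    # Notebook default: the class head is left out, so only the box head is restored.
    print(toy_restore(SAVED, MODEL, {"_box_prediction_head"}))
    # Attaching _prediction_heads makes the class head take part, and it fails.
    try:
        toy_restore(SAVED, MODEL, {"_box_prediction_head", "_prediction_heads"})
    except ValueError as err:
        print("ValueError:", err)
```

The point of the toy is that leaving `_prediction_heads` out of the fake checkpoint is what lets a head with a different number of classes start from scratch instead of colliding with the checkpoint's shapes.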
saikumarchalla commented 3 years ago

Was able to reproduce the issue. Please find the gist here. Thanks!

Jaredeco commented 3 years ago

Hello, thank you for your help, but when I run your notebook I get the same error as before. Could you help me with this, please?

Jaredeco commented 3 years ago

Please, does anyone know how to fix this error? I have been stuck on it for a long time and still have not been able to make it work.

iamarchisha commented 3 years ago

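@Jaredeco That error occurs because you are trying to restore the classification head from a checkpoint that was trained with num_classes=90, while your model uses a different number of classes, so the shapes of the head's weights do not match. You can try the following to see for yourself: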

  1. Uncomment _prediction_heads=detection_model._box_predictor._prediction_heads, and set num_classes=90 --> the weights restore successfully.
  2. Uncomment _prediction_heads=detection_model._box_predictor._prediction_heads, and set num_classes=1 --> the weights fail to restore (shape mismatch).
  3. Comment out _prediction_heads=detection_model._box_predictor._prediction_heads, and set num_classes=1 --> the weights restore successfully, and the classification head is initialized from scratch.
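The two numbers in the ValueError are consistent with this explanation. A minimal sketch, assuming the class head predicts num_classes + 1 scores (the extra slot is background) for each of 6 anchors per feature-map location. The 6 is an assumption here, but it is the greatest common divisor of 546 and 24, and it gives 546 = 6 × 91 (90 classes) and 24 = 6 × 4 (3 classes), matching a 90-class checkpoint and the three-class model from the question. The helper names are hypothetical.

```python
def class_head_channels(num_classes: int, anchors_per_location: int = 6) -> int:
    """Output channels of the class-prediction conv: (classes + background) per anchor."""
    return anchors_per_location * (num_classes + 1)


def num_classes_from_channels(channels: int, anchors_per_location: int = 6) -> int:
    """Inverse of class_head_channels; raises if the channel count does not fit."""
    slots, remainder = divmod(channels, anchors_per_location)
    if remainder or slots < 2:
        raise ValueError(
            f"{channels} channels do not fit {anchors_per_location} anchors per location")
    return slots - 1  # drop the background slot


if __name__ == "__main__":
    print(num_classes_from_channels(546))  # 90: last dim of the checkpoint kernel
    print(num_classes_from_channels(24))   # 3: last dim the new model expects
    print(class_head_channels(1))          # 12: what experiment 2 (num_classes=1) would expect
```

Under the same assumption, any num_classes other than 90 produces a last dimension other than 546, which is why only experiment 1 restores cleanly with `_prediction_heads` uncommented.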

As mentioned here for issue #9133, if you want to restore the classification head, you need to modify the checkpoint file (change its number of classes) so that its shapes match your model.
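For reference, a minimal NumPy sketch of one way to change the number of classes in the checkpoint's class head: slice the kernel and bias down to background plus the first k class slots for every anchor. Assumptions, not taken from this thread: the output channels are ordered anchor-major (channel = anchor × slots + slot, with slot 0 as background), the arrays have already been read from the checkpoint (for example with `tf.train.load_variable`), and `trim_class_head` is a hypothetical name. Writing the trimmed values back into a checkpoint is not shown. The kept slots belong to the checkpoint's original classes, so they only serve as an initialization for your own classes.

```python
import numpy as np


def trim_class_head(kernel: np.ndarray, bias: np.ndarray, anchors: int,
                    old_slots: int, new_slots: int):
    """Keep the first `new_slots` class slots (slot 0 = background) for every anchor.

    kernel: (kh, kw, in_channels, anchors * old_slots)
    bias:   (anchors * old_slots,)
    ASSUMPTION: channels are anchor-major, i.e. channel = anchor * old_slots + slot.
    """
    if kernel.shape[-1] != anchors * old_slots or bias.shape != (anchors * old_slots,):
        raise ValueError("kernel/bias do not have anchors * old_slots output channels")
    if not 0 < new_slots <= old_slots:
        raise ValueError("new_slots must be between 1 and old_slots")
    spatial = kernel.shape[:-1]
    # Split the last axis into (anchor, slot), keep the leading slots, flatten again.
    k = kernel.reshape(spatial + (anchors, old_slots))[..., :new_slots]
    b = bias.reshape(anchors, old_slots)[:, :new_slots]
    return k.reshape(spatial + (anchors * new_slots,)), b.reshape(anchors * new_slots)


if __name__ == "__main__":
    # 90 classes + background -> 3 classes + background, 6 anchors per location.
    kernel = np.zeros((3, 3, 256, 546), dtype=np.float32)
    bias = np.zeros(546, dtype=np.float32)
    new_kernel, new_bias = trim_class_head(kernel, bias, anchors=6, old_slots=91, new_slots=4)
    print(new_kernel.shape, new_bias.shape)
```

Slicing per anchor, rather than taking the first 24 channels, keeps every anchor's background slot; if the real layout differed from the assumed one, the reshape would silently mix anchors, so it is worth checking against the Object Detection API source before relying on it.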