tensorflow / models

Models and examples built with TensorFlow

Converting CenterNet MobileNetV2 to TFLite raises: Tensor's shape (256,) is not compatible with supplied shape (1,) #9834

Open tungnat97 opened 3 years ago

tungnat97 commented 3 years ago

Prerequisites

Please answer the following questions for yourself before submitting an issue.

1. The entire URL of the file you are using

CenterNet MobileNet V2

2. Describe the bug

When running the following command:

    python3 models/research/object_detection/export_tflite_graph_tf2.py \
      --pipeline_config_path centernet_mobilenetv2_fpn_kpts/pipeline.config \
      --trained_checkpoint_dir centernet_mobilenetv2_fpn_kpts/checkpoint/ \
      --output_directory centernet_mobilenetv2_fpn_kpts/tflite/ \
      --centernet_include_keypoints true \
      --keypoint_label_map_path centernet_mobilenetv2_fpn_kpts/label_map.txt \
      --max_detections 1 \
      --config_override "model { center_net { image_resizer { fixed_shape_resizer { height: 256 width: 256 } } } }"

I get the following error message:

Traceback (most recent call last):
  File "models/research/object_detection/export_tflite_graph_tf2.py", line 161, in <module>
    app.run(main)
  File "/home/tung/.local/lib/python3.8/site-packages/absl/app.py", line 303, in run
    _run_main(main, args)
  File "/home/tung/.local/lib/python3.8/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "models/research/object_detection/export_tflite_graph_tf2.py", line 154, in main
    export_tflite_graph_lib_tf2.export_tflite_model(
  File "/home/tung/.local/lib/python3.8/site-packages/object_detection/export_tflite_graph_lib_tf2.py", line 366, in export_tflite_model
    concrete_function = detection_module.inference_fn.get_concrete_function(
  File "/home/tung/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1299, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/tung/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1205, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/tung/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/tung/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/tung/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/tung/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/tung/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/tung/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/tung/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3887, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/home/tung/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /home/tung/.local/lib/python3.8/site-packages/object_detection/export_tflite_graph_lib_tf2.py:288 inference_fn  *
        prediction_dict = self._model.predict(image, None)
    /home/tung/.local/lib/python3.8/site-packages/object_detection/meta_architectures/center_net_meta_arch.py:3288 predict  *
        predictions[head_name] = [
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py:1012 __call__  **
        outputs = call_fn(inputs, *args, **kwargs)
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/sequential.py:389 call
        outputs = layer(inputs, **kwargs)
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py:1008 __call__
        self._maybe_build(inputs)
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py:2710 _maybe_build
        self.build(input_shapes)  # pylint:disable=not-callable
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/keras/layers/convolutional.py:207 build
        self.bias = self.add_weight(
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py:623 add_weight
        variable = self._add_variable_with_custom_getter(
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/training/tracking/base.py:805 _add_variable_with_custom_getter
        new_variable = getter(
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer_utils.py:130 make_variable
        return tf_variables.VariableV1(
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:260 __call__
        return cls._variable_v1_call(*args, **kwargs)
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:206 _variable_v1_call
        return previous_getter(
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py:712 variable_capturing_scope
        v = UnliftedInitializerVariable(
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:264 __call__
        return super(VariableMetaclass, cls).__call__(*args, **kwargs)
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py:227 __init__
        initial_value = initial_value()
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/training/tracking/base.py:81 __call__
        return CheckpointInitialValue(
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/training/tracking/base.py:117 __init__
        self.wrapped_value.set_shape(shape)
    /home/tung/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1215 set_shape
        raise ValueError(

    ValueError: Tensor's shape (256,) is not compatible with supplied shape (1,)
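
Reading the traceback, the final ValueError appears to come from a static shape check in set_shape: the value being restored has shape (256,), while the shape supplied for the bias variable being built is (1,). The snippet below is a simplified, pure-Python stand-in for that check, not TensorFlow's own code. The names is_compatible and set_shape are illustrative, and it only models shapes of known rank, with None for an unknown dimension.

```python
from typing import Optional, Sequence, Tuple

Dim = Optional[int]  # None stands for an unknown dimension


def is_compatible(shape_a: Sequence[Dim], shape_b: Sequence[Dim]) -> bool:
    """Return True if two static shapes could describe the same tensor.

    Simplified rule: same rank, and every pair of dimensions is either
    equal or has at least one unknown (None) entry.
    """
    if len(shape_a) != len(shape_b):
        return False
    return all(a is None or b is None or a == b
               for a, b in zip(shape_a, shape_b))


def set_shape(tensor_shape: Sequence[Dim],
              supplied_shape: Sequence[Dim]) -> Tuple[Dim, ...]:
    """Merge supplied_shape into tensor_shape, or raise on a mismatch."""
    if not is_compatible(tensor_shape, supplied_shape):
        raise ValueError(
            f"Tensor's shape {tuple(tensor_shape)} is not compatible "
            f"with supplied shape {tuple(supplied_shape)}")
    # Prefer the known dimension wherever one side is unknown.
    return tuple(b if a is None else a
                 for a, b in zip(tensor_shape, supplied_shape))


if __name__ == "__main__":
    print(set_shape((None,), (256,)))   # unknown dim is refined
    try:
        set_shape((256,), (1,))         # the case from this issue
    except ValueError as err:
        print(err)
```

Under this rule a 256-element value can never be refined to a 1-element shape, so the two sides of the restore disagree about the size of that bias.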

3. Steps to reproduce
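
A minimal sketch of the reproduction, assuming the same relative paths as the command in section 2 and a working install of the TensorFlow Object Detection API. The helper name build_export_command is illustrative and not part of the repository. It only assembles the argument list, keeping the config override as a single argument.

```python
import shlex

MODEL_DIR = "centernet_mobilenetv2_fpn_kpts"

# Kept as one string so it reaches the script as a single argument.
CONFIG_OVERRIDE = (
    "model { center_net { image_resizer { fixed_shape_resizer "
    "{ height: 256 width: 256 } } } }"
)


def build_export_command(model_dir=MODEL_DIR):
    """Return the export command from section 2 as an argv list."""
    return [
        "python3",
        "models/research/object_detection/export_tflite_graph_tf2.py",
        "--pipeline_config_path", f"{model_dir}/pipeline.config",
        "--trained_checkpoint_dir", f"{model_dir}/checkpoint/",
        "--output_directory", f"{model_dir}/tflite/",
        "--centernet_include_keypoints", "true",
        "--keypoint_label_map_path", f"{model_dir}/label_map.txt",
        "--max_detections", "1",
        "--config_override", CONFIG_OVERRIDE,
    ]


if __name__ == "__main__":
    # Print a copy-pasteable shell line; shlex.join quotes the override.
    print(shlex.join(build_export_command()))
    # To actually run it (requires TensorFlow and the checkpoint files):
    # import subprocess
    # subprocess.run(build_export_command(), check=True)
```

Running the printed command from the directory that contains both models/ and centernet_mobilenetv2_fpn_kpts/ should produce the traceback shown in section 2.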

Expected behavior: the TFLite model is converted properly.

5. Additional context

Include any logs that would be helpful to diagnose the problem.

6. System information

Abhishekvats1997 commented 3 years ago

Hi, this might be a bit unrelated, but could you share any insight into the config you used while training the CenterNet MobileNetV2? I have been struggling with this for days: the training loss decreases and everything looks fine during training, but the model fails to give predictions on the validation set and the mAP is close to zero.