gevero / enet_tensorflow

Enet implementation in tensorflow

can't convert ENet model to tflite and TRT #1

Closed UcefMountacer closed 2 years ago

UcefMountacer commented 3 years ago

Hi,

Thanks for the code. I liked the notebook; it's very well made.

I tried to convert this model to TFLite and then to TensorRT. After converting to TFLite int16, the interpreter crashes when invoked. I have seen similar behaviour with a converted Mask-RCNN model.

Is this because the model has layers that are not yet supported by TFLite? Note that I enabled both the builtin ops and the select TF ops:

converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
  tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]

UcefMountacer commented 3 years ago

Update

The problem occurs on the GPU, not the CPU.

For TensorFlow-TRT, I can't optimize the model because some layers can't be serialized.

I think I should close the issue since the problem is linked to testing tflite on GPU.

gevero commented 3 years ago

Hi

Which layers were causing the problems? Was it the MaxUnpool? If so, maybe we can replace it with a deconvolution.

UcefMountacer commented 3 years ago

When trying to save it in SavedModel format it doesn't work...

print('Saving model...')
EnetEndToEnd.save('/content/')

I get the following error :

ValueError: Model <models.EnetModel object at 0x7fa6343b3b70> cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling .fit() or .predict(). To manually set the shapes, call model.build(input_shape).

This is a problem because TF-TRT in TensorFlow 2 needs the model in SavedModel format.

So I used the TF 1 workflow instead, freezing the model first (which is really annoying).

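import numpy as np
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# Wrap the model in a concrete function with a fixed input shape
full_model = tf.function(lambda x: EnetEndToEnd(x))
full_model = full_model.get_concrete_function(tf.TensorSpec((1, 360, 480, 3), np.float32))
# Freeze the variables into constants to get the frozen graph def
frozen_func = convert_variables_to_constants_v2(full_model)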
frozen_func.graph.as_graph_def()
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir='',
                  name="frozen_graph.pb",
                  as_text=False)
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir='',
                  name="frozen_graph.pbtxt",
                  as_text=True)

In an earlier test the freeze worked, but when I re-ran it to show you the errors it failed (Colab keeps crashing when calling the convert_variables_to_constants_v2 function). The only thing I changed was the number of classes.

I will try again later because I think there is an issue with Google Colab today (crashes and very slow inference).

UcefMountacer commented 3 years ago

@gevero this is what I get by using tf.saved_model.save() to save the model :

WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa96d051128>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa96d051320>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa96d069198>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa96d069390>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa96d07d208>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa96d07d400>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa960435278>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa960435470>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa960467358>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa960467550>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa9603fe358>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa9603fe550>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa960412630>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa960412828>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa9603ad630>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa9603ad828>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa9603c16a0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa9603c1898>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa96d1e4828>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa96d1d53c8>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa96f67cc88>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa96f67ca58>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa96fcf4518>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa971572ef0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa9603fe518>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa9603fe7b8>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa96039bb38>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa96039bd30>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa960333e10>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa96033e048>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa960349e10>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa960353048>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa960360e80>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa9602ed0b8>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa9602f9e80>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa9603050b8>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa960319198>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa960319390>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <layers.MaxPoolWithArgmax2D object at 0x7fa9602b0198>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa9602b0390>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fa9602b0908>, because it is not built.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-36-1978bbbaf5f1> in <module>()
----> 1 tf.saved_model.save(Enet , '/content/')

31 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
    974 
    975   _, exported_graph, object_saver, asset_info = _build_meta_graph(
--> 976       obj, export_dir, signatures, options, meta_graph_def)
    977   saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
    978 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, export_dir, signatures, options, meta_graph_def)
   1045   if signatures is None:
   1046     signatures = signature_serialization.find_function_to_export(
-> 1047         checkpoint_graph_view)
   1048 
   1049   signatures, wrapped_functions = (

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
     73   # If the user did not specify signatures, check the root object for a function
     74   # that can be made into a signature.
---> 75   functions = saveable_view.list_functions(saveable_view.root)
     76   signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
     77   if signature is not None:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/save.py in list_functions(self, obj, extra_functions)
    143     if obj_functions is None:
    144       obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
--> 145           self._serialization_cache)
    146       self._functions[obj] = obj_functions
    147     if extra_functions:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in _list_functions_for_serialization(self, serialization_cache)
   2588     self.predict_function = None
   2589     functions = super(
-> 2590         Model, self)._list_functions_for_serialization(serialization_cache)
   2591     self.train_function = train_function
   2592     self.test_function = test_function

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
   3017   def _list_functions_for_serialization(self, serialization_cache):
   3018     return (self._trackable_saved_model_saver
-> 3019             .list_functions_for_serialization(serialization_cache))
   3020 
   3021   def __getstate__(self):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
     85         `ConcreteFunction`.
     86     """
---> 87     fns = self.functions_to_serialize(serialization_cache)
     88 
     89     # The parent AutoTrackable class saves all user-defined tf.functions, and

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
     77   def functions_to_serialize(self, serialization_cache):
     78     return (self._get_serialized_attributes(
---> 79         serialization_cache).functions_to_serialize)
     80 
     81   def _get_serialized_attributes(self, serialization_cache):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
     93 
     94     object_dict, function_dict = self._get_serialized_attributes_internal(
---> 95         serialization_cache)
     96 
     97     serialized_attr.set_and_validate_objects(object_dict)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
     55     objects, functions = (
     56         super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
---> 57             serialization_cache))
     58     functions['_default_save_signature'] = default_signature
     59     return objects, functions

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
    102     """Returns dictionary of serialized attributes."""
    103     objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
--> 104     functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
    105     # Attribute validator requires that the default save signature is added to
    106     # function dict, even if the value is None.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrap_layer_functions(layer, serialization_cache)
    153   # Reset the losses of the layer and its children. The call function in each
    154   # child layer is replaced with tf.functions.
--> 155   original_fns = _replace_child_layer_functions(layer, serialization_cache)
    156   original_losses = _reset_layer_losses(layer)
    157 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in _replace_child_layer_functions(layer, serialization_cache)
    272       serialized_functions = (
    273           child_layer._trackable_saved_model_saver._get_serialized_attributes(
--> 274               serialization_cache).functions)
    275     else:
    276       serialized_functions = (

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
     93 
     94     object_dict, function_dict = self._get_serialized_attributes_internal(
---> 95         serialization_cache)
     96 
     97     serialized_attr.set_and_validate_objects(object_dict)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
     55     objects, functions = (
     56         super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
---> 57             serialization_cache))
     58     functions['_default_save_signature'] = default_signature
     59     return objects, functions

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
    102     """Returns dictionary of serialized attributes."""
    103     objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
--> 104     functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
    105     # Attribute validator requires that the default save signature is added to
    106     # function dict, even if the value is None.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrap_layer_functions(layer, serialization_cache)
    163   call_fn_with_losses = call_collection.add_function(
    164       _wrap_call_and_conditional_losses(layer),
--> 165       '{}_layer_call_and_return_conditional_losses'.format(layer.name))
    166   call_fn = call_collection.add_function(
    167       _extract_outputs_from_fn(layer, call_fn_with_losses),

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in add_function(self, call_fn, name)
    503       # Manually add traces for layers that have keyword arguments and have
    504       # a fully defined input signature.
--> 505       self.add_trace(*self._input_signature)
    506     return fn
    507 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in add_trace(self, *args, **kwargs)
    418             fn.get_concrete_function(*args, **kwargs)
    419 
--> 420         trace_with_training(True)
    421         trace_with_training(False)
    422       else:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in trace_with_training(value, fn)
    416           utils.set_training_arg(value, self._training_arg_index, args, kwargs)
    417           with K.deprecated_internal_learning_phase_scope(value):
--> 418             fn.get_concrete_function(*args, **kwargs)
    419 
    420         trace_with_training(True)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in get_concrete_function(self, *args, **kwargs)
    547     if not self.call_collection.tracing:
    548       self.call_collection.add_trace(*args, **kwargs)
--> 549     return super(LayerCall, self).get_concrete_function(*args, **kwargs)
    550 
    551 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
   1165       ValueError: if this object has not yet been called on concrete values.
   1166     """
-> 1167     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1168     concrete._garbage_collector.release()  # pylint: disable=protected-access
   1169     return concrete

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   1071       if self._stateful_fn is None:
   1072         initializers = []
-> 1073         self._initialize(args, kwargs, add_initializers_to=initializers)
   1074         self._initialize_uninitialized_variables(initializers)
   1075 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    695     self._concrete_stateful_fn = (
    696         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 697             *args, **kwds))
    698 
    699     def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2853       args, kwargs = None, None
   2854     with self._lock:
-> 2855       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2856     return graph_function
   2857 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3211 
   3212       self._function_cache.missed.add(call_context_key)
-> 3213       graph_function = self._create_graph_function(args, kwargs)
   3214       self._function_cache.primary[cache_key] = graph_function
   3215       return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3073             arg_names=arg_names,
   3074             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3075             capture_by_value=self._capture_by_value),
   3076         self._function_attributes,
   3077         function_spec=self.function_spec,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    984         _, original_func = tf_decorator.unwrap(python_func)
    985 
--> 986       func_outputs = python_func(*func_args, **func_kwargs)
    987 
    988       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    598         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    599         # the function a weak reference to itself to avoid a reference cycle.
--> 600         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    601     weak_wrapped_fn = weakref.ref(wrapped_fn)
    602 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
    524         saving=True):
    525       with ops.enable_auto_cast_variables(layer._compute_dtype_object):  # pylint: disable=protected-access
--> 526         ret = method(*args, **kwargs)
    527     _restore_layer_losses(original_losses)
    528     return ret

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrap_with_training_arg(*args, **kwargs)
    484         kwargs = kwargs.copy()
    485         utils.remove_training_arg(self._training_arg_index, args, kwargs)
--> 486         return call_fn(*args, **kwargs)
    487 
    488       return tf_decorator.make_decorator(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(inputs, *args, **kwargs)
    567   def call_and_return_conditional_losses(inputs, *args, **kwargs):
    568     """Returns layer (call_output, conditional losses) tuple."""
--> 569     call_output = layer_call(inputs, *args, **kwargs)
    570     if version_utils.is_v1_layer_or_model(layer):
    571       conditional_losses = layer.get_losses_for(inputs)

TypeError: call() missing 1 required positional argument: 'upsample_layer'
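
A note on the final error: it is consistent with the exporter tracing each layer's `call` with its `inputs` argument only, so a second required positional argument such as `upsample_layer` is left unfilled. The sketch below reproduces that mechanism in plain Python, without TensorFlow. Every name in it (`trace_like_exporter`, `UnpoolTwoArgs`, `UnpoolPacked`) is a hypothetical stand-in, not a class from this repository or from Keras. It also shows one common workaround, assuming the layer can be changed: pack all tensors into a single `inputs` list.

```python
# Framework-free sketch; all names are hypothetical stand-ins.

def trace_like_exporter(layer, example_input):
    """Call the layer the way a tracer might: with `inputs` only."""
    return layer.call(example_input)


class UnpoolTwoArgs:
    """The second tensor is a separate required positional argument."""

    def call(self, inputs, upsample_layer):
        return [inputs, upsample_layer]


class UnpoolPacked:
    """Both tensors travel together in one `inputs` list."""

    def call(self, inputs):
        features, upsample_layer = inputs
        return [features, upsample_layer]


# The two-argument signature fails when only `inputs` is supplied.
try:
    trace_like_exporter(UnpoolTwoArgs(), "features")
    two_arg_error = None
except TypeError as err:
    two_arg_error = str(err)

# The packed signature receives everything through the single argument.
packed_result = trace_like_exporter(UnpoolPacked(), ["features", "indices"])

print(two_arg_error)
print(packed_result)
```

Whether packing the inputs is enough for the real model depends on how its layers are wired, so treat this as an illustration of the failure mode rather than a confirmed fix.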

When using tf.keras.models.save_model(Enet, '/content/') I get this:

ValueError: Model <models.EnetModel object at 0x7fa89918aef0> cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling .fit() or .predict(). To manually set the shapes, call model.build(input_shape).

gevero commented 3 years ago

The problem is that the model, until fit is called, has no implicit input shape. So you cannot export it until the shape has been assigned. My guess is that first you have to build the model with model.build((n_batch,img_h,img_w,n_channels)) and then export it. It almost looks like you are trying to export an untrained model, or a model that has not been built.

Best

Giovanni

UcefMountacer commented 3 years ago

I tried build, but I did it with this input: Tensor(1,360,480,3)! I think it's the same as (n_batch,img_h,img_w,n_channels) 🤔.

"It almost looks like you are trying to export an untrained model, or a model that has not been built."

No. I tried it after training, and also after importing the model in .tf format, with no success. I guess the problem is that the layers and submodels are custom, and that the get_config() method should be overridden in all of these classes (you did it for the custom layers only). I tried overriding it in the submodels, but it didn't work. I guess I did something wrong, since it's the first time I have faced this problem.

Here you can find the method I wrote for the custom models and the Enet model too (if that helps 😄).

Best

get_config_for_submodels_and_enet.txt
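
For reference, the `get_config()` idea above rests on a simple contract: the method returns the constructor arguments as plain values, and `from_config()` rebuilds the object from them. The sketch below shows that contract in plain Python, without TensorFlow. The `Bottleneck` class and its arguments are hypothetical and are not the classes from this repository or from the attached file.

```python
import json

# Framework-free sketch of the get_config()/from_config() contract.
# `Bottleneck` and its arguments are hypothetical.


class Bottleneck:
    def __init__(self, filters, kernel_size=3, name="bottleneck"):
        self.filters = filters
        self.kernel_size = kernel_size
        self.name = name

    def get_config(self):
        # Return every constructor argument as a JSON-friendly value.
        return {
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "name": self.name,
        }

    @classmethod
    def from_config(cls, config):
        # Rebuild the object from the dictionary produced above.
        return cls(**config)


original = Bottleneck(filters=64, kernel_size=5)

# Round-trip through JSON to mimic a config being written to disk.
config = json.loads(json.dumps(original.get_config()))
restored = Bottleneck.from_config(config)

print(restored.get_config())
```

In a Keras layer the dictionary would normally start from the base class's `get_config()` and add the extra arguments. A constructor argument missing from the dictionary is a typical reason a reload fails.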