tensorflow / tensorflow

An Open Source Machine Learning Framework for Everyone
https://tensorflow.org
Apache License 2.0

Unable to convert mT5 model to tflite (tensorflow.GraphDef exceeds maximum protobuf size of 2GB) #47326

Closed: Arman-IMRSV closed this issue 3 years ago

Arman-IMRSV commented 3 years ago

1. System information

2. Code

import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, TFAutoModelForSeq2SeqLM
import tensorflow as tf

model_name = "google/mt5-base"
config = AutoConfig.from_pretrained(
    model_name
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name
)
model = TFAutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    from_pt=True,
    config=config
)

# mT5's vocabulary (~250k tokens) exceeds the int16 range, so token IDs are declared as int32.
input_spec = tf.TensorSpec([1, 64], tf.int32)
model._set_inputs(input_spec, training=False)
print(model.inputs)
print(model.outputs)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.float32
converter.target_spec.supported_ops = [tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open("mt5_base_en_to_fa.tflite", "wb").write(tflite_model)

3. Error Message:

ValueError                                Traceback (most recent call last)
<ipython-input-78-e62147be515c> in <module>
----> 1 tflite_model = converter.convert()
      2 open("mt5_base_en_to_fa.tflite", "wb").write(tflite_model)

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/lite/python/lite.py in convert(self)
    807     frozen_func, graph_def = (
    808         _convert_to_constants.convert_variables_to_constants_v2_as_graph(
--> 809             self._funcs[0], lower_control_flow=False))
    810 
    811     input_tensors = [

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/framework/convert_to_constants.py in convert_variables_to_constants_v2_as_graph(func, lower_control_flow, aggressive_inlining)
   1107 
   1108   frozen_func = _construct_concrete_function(func, output_graph_def,
-> 1109                                              converted_input_indices)
   1110   return frozen_func, output_graph_def
   1111 

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/framework/convert_to_constants.py in _construct_concrete_function(func, output_graph_def, converted_input_indices)
    999   new_func = wrap_function.function_from_graph_def(output_graph_def,
   1000                                                    new_input_names,
-> 1001                                                    new_output_names)
   1002 
   1003   # Manually propagate shape for input tensors where the shape is not correctly

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/eager/wrap_function.py in function_from_graph_def(graph_def, inputs, outputs)
    648     importer.import_graph_def(graph_def, name="")
    649 
--> 650   wrapped_import = wrap_function(_imports_graph_def, [])
    651   import_graph = wrapped_import.graph
    652   return wrapped_import.prune(

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/eager/wrap_function.py in wrap_function(fn, signature, name)
    626           signature=signature,
    627           add_control_dependencies=False,
--> 628           collections={}),
    629       variable_holder=holder,
    630       signature=signature)

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    984         _, original_func = tf_decorator.unwrap(python_func)
    985 
--> 986       func_outputs = python_func(*func_args, **func_kwargs)
    987 
    988       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/eager/wrap_function.py in __call__(self, *args, **kwargs)
     85 
     86   def __call__(self, *args, **kwargs):
---> 87     return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)
     88 
     89   def call_with_variable_creator_scope(self, fn):

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/eager/wrap_function.py in wrapped(*args, **kwargs)
     91     def wrapped(*args, **kwargs):
     92       with variable_scope.variable_creator_scope(self.variable_creator_scope):
---> 93         return fn(*args, **kwargs)
     94 
     95     return wrapped

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/eager/wrap_function.py in _imports_graph_def()
    646 
    647   def _imports_graph_def():
--> 648     importer.import_graph_def(graph_def, name="")
    649 
    650   wrapped_import = wrap_function(_imports_graph_def, [])

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/framework/importer.py in import_graph_def(***failed resolving arguments***)
    403       return_elements=return_elements,
    404       name=name,
--> 405       producer_op_list=producer_op_list)
    406 
    407 

~/.pyenv/versions/3.7.2/lib/python3.7/site-packages/tensorflow/python/framework/importer.py in _import_graph_def_internal(graph_def, input_map, return_elements, validate_colocation_constraints, name, producer_op_list)
    492   # _ProcessNewOps.
    493   with graph._mutation_lock():  # pylint: disable=protected-access
--> 494     with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
    495       try:
    496         results = c_api.TF_GraphImportGraphDefWithResults(

ValueError: Message tensorflow.GraphDef exceeds maximum protobuf size of 2GB: 2330463933

4. Note:

  1. I noticed that support for models larger than 2GB has been added to the ONNX converter (here). Has a similar fix been added to the TFLite converter as well? If not, how can I work around this error?
  2. I have been able to convert the "google/mt5-small" model to TFLite, but not mt5-base, because it is a larger model.

Thanks.
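The size difference in note 2 lines up with a simple estimate: when the converter freezes variables into constants (the `convert_variables_to_constants_v2_as_graph` call in the traceback), every float32 weight is copied into the GraphDef. The sketch below is a back-of-the-envelope check, assuming roughly 300M parameters for mt5-small and 580M for mt5-base (rounded figures from the mT5 paper) and a protobuf cap of 2**31 - 1 bytes; it ignores the non-weight part of the graph.

```python
# Back-of-the-envelope check of the 2GB GraphDef error reported in this issue.
PROTOBUF_MAX_BYTES = 2**31 - 1            # assumed cap: INT_MAX bytes per serialized message
REPORTED_GRAPHDEF_BYTES = 2_330_463_933   # size from the ValueError above
BYTES_PER_FLOAT32 = 4

# Approximate parameter counts (assumption: rounded figures from the mT5 paper).
APPROX_PARAMS = {
    "google/mt5-small": 300_000_000,
    "google/mt5-base": 580_000_000,
}


def frozen_weight_bytes(num_params: int, bytes_per_param: int = BYTES_PER_FLOAT32) -> int:
    """Bytes taken by the weights alone once frozen into the GraphDef as constants."""
    return num_params * bytes_per_param


def fits_in_one_protobuf(num_bytes: int) -> bool:
    """True if a message of this size stays within the assumed protobuf cap."""
    return num_bytes <= PROTOBUF_MAX_BYTES


if __name__ == "__main__":
    for name, params in APPROX_PARAMS.items():
        size = frozen_weight_bytes(params)
        print(f"{name}: ~{size / 1e9:.2f} GB of float32 weights, fits: {fits_in_one_protobuf(size)}")
    print("reported overshoot:", REPORTED_GRAPHDEF_BYTES - PROTOBUF_MAX_BYTES, "bytes")
```

Run as a script, it prints about 1.20 GB (fits: True) for mt5-small, about 2.32 GB (fits: False) for mt5-base, and an overshoot of 182980286 bytes. The 2.32 GB estimate is close to the 2,330,463,933 bytes in the error, which is consistent with the frozen weights dominating the GraphDef.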

abattery commented 3 years ago

Hi @Arman-IMRSV

Is it possible to follow this workaround in https://github.com/tensorflow/tensorflow/issues/45041#issuecomment-731268801 ?

Arman-IMRSV commented 3 years ago

@abattery I do not believe it is possible to apply that solution here, as I am loading a pretrained model. Do you have any idea how I could apply it in this case?

Arman-IMRSV commented 3 years ago

@abattery Do you have any suggested solution? I also had a look at converting ONNX to TFLite. It did not work either.

abattery commented 3 years ago

Do you think it is possible to follow the above suggestion in the conversion code from ONNX to TF?

Arman-IMRSV commented 3 years ago

@abattery I do not believe so. Again, in ONNX we have the frozen graph, so I do not think we can apply that solution here. Is there any other workaround?

Arman-IMRSV commented 3 years ago

@TomWildenhain-Microsoft Is there a way to add support for large models to TFLite conversion, similar to what you did for ONNX conversion in this PR?

Arman-IMRSV commented 3 years ago

@abattery Could you please let me know whom I should ask for help? Who has worked on the TFLite implementation and might be able to help?

TomWildenhain-Microsoft commented 3 years ago

The way we did this with the ONNX converter is a bit of a hack, but it would probably work here. Just curious, why do you want to convert this model to TFLite? TFLite is designed to run on low-power devices, so I'm surprised you'd want to run such a huge model on it.

Arman-IMRSV commented 3 years ago

@TomWildenhain-Microsoft I intend to use mt5-base on a mobile device for a multilingual translation task. I know it's pretty big, but I believe I should be able to fit it on a mobile device after compression.

Do you have any idea how to apply that hack here as well?

abattery commented 3 years ago

The ONNX-TFLite converter actually creates the corresponding TensorFlow graph from the given original model. Could you file a feature request with the ONNX-TFLite converter asking it to pull the weights out separately and use the saved model format for weight serialization, instead of inlining the weights in the operator definitions of a single protobuf file, as in https://github.com/tensorflow/tensorflow/issues/45041#issuecomment-731268801?
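The suggestion above (keep the graph definition small and serialize the weights separately, the way a SavedModel keeps variables in its `variables/` files rather than inlining them as constants in one GraphDef) can be illustrated with a toy, stdlib-only Python sketch. Everything here is hypothetical: the `graph.json`/`weights.bin` layout and the helpers `save_with_external_weights` and `load_weight` are not the SavedModel, ONNX external-data, or TFLite formats and not part of any TensorFlow API. The point is only that the graph file records names and byte offsets, so its size stays roughly constant while the weight file grows with the model.

```python
import json
import os
import struct
import tempfile


def save_with_external_weights(graph_ops, weights, out_dir):
    """Write a small JSON 'graph' plus a separate raw float32 weights file.

    graph_ops: list of op dicts that refer to weights by name only.
    weights: dict mapping weight name -> list of floats.
    Returns (graph_path, data_path).
    """
    data_path = os.path.join(out_dir, "weights.bin")
    index = {}
    offset = 0
    with open(data_path, "wb") as f:
        for name, values in weights.items():
            blob = struct.pack(f"<{len(values)}f", *values)  # little-endian float32
            f.write(blob)
            index[name] = {"offset": offset, "count": len(values)}
            offset += len(blob)
    graph_path = os.path.join(out_dir, "graph.json")
    with open(graph_path, "w") as f:
        # The graph stores only names and offsets, so its size does not grow with the weights.
        json.dump({"ops": graph_ops, "weights": index, "data_file": "weights.bin"}, f)
    return graph_path, data_path


def load_weight(graph_path, name):
    """Read one named weight back from the external data file."""
    with open(graph_path) as f:
        graph = json.load(f)
    entry = graph["weights"][name]
    data_path = os.path.join(os.path.dirname(graph_path), graph["data_file"])
    with open(data_path, "rb") as f:
        f.seek(entry["offset"])
        raw = f.read(4 * entry["count"])
    return list(struct.unpack(f"<{entry['count']}f", raw))


if __name__ == "__main__":
    ops = [{"op": "MatMul", "inputs": ["x", "dense/kernel"]},
           {"op": "BiasAdd", "inputs": ["matmul", "dense/bias"]}]
    weights = {"dense/kernel": [0.5] * 100_000, "dense/bias": [0.0, 1.0]}
    with tempfile.TemporaryDirectory() as d:
        graph_path, data_path = save_with_external_weights(ops, weights, d)
        print("graph bytes:", os.path.getsize(graph_path))
        print("weight bytes:", os.path.getsize(data_path))
        print("bias:", load_weight(graph_path, "dense/bias"))
```

Run as a script, the weight file is 400,008 bytes (100,002 float32 values) while the graph file stays well under 1 KB. Whether the TFLite converter can consume weights stored this way for a model above 2GB was not resolved in this thread.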

google-ml-butler[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

google-ml-butler[bot] commented 3 years ago

Closing as stale. Please reopen if you'd like to work on this further.
