shamangary / FSA-Net

[CVPR19] FSA-Net: Learning Fine-Grained Structure Aggregation for Head Pose Estimation from a Single Image
Apache License 2.0
612 stars 155 forks

[TensorRT] Convert TF frozen graph in TF-TRT #38

Open nhmnhat1997 opened 5 years ago

nhmnhat1997 commented 5 years ago

Hi, I'm trying to convert your pretrained frozen graph to TF-TRT using this code:

import tensorflow as tf
from tensorflow.python.compiler.tensorrt import trt_convert as trt
with tf.Session() as sess:
    # First deserialize your frozen graph:
    with tf.gfile.GFile('./fsanet_capsule_3_16_2_21_5.pb', 'rb') as f:
        frozen_graph = tf.GraphDef()
        frozen_graph.ParseFromString(f.read())
    # Now you can create a TensorRT inference graph from your
    # frozen graph:
    converter = trt.TrtGraphConverter(
        input_graph_def=frozen_graph,
        nodes_blacklist=['logits', 'classes']) #output nodes
    trt_graph = converter.convert()
    # Import the TensorRT graph into a new graph and run:
    output_node = tf.import_graph_def(
        trt_graph,
        return_elements=['logits', 'classes'])
    sess.run(output_node)

The error message is:

InvalidArgumentError: Traceback (most recent call last)
<ipython-input-3-07f6839266b8> in <module>
      7     # frozen graph:
      8     converter = trt.TrtGraphConverter(input_graph_def=frozen_graph, nodes_blacklist=['classes']) #output nodes
----> 9     trt_graph = converter.convert()
     10     # Import the TensorRT graph into a new graph and run:
     11     output_node = tf.import_graph_def(

~/python_env/tensorflow_gpu/lib/python3.6/site-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in convert(self)
    296     assert not self._converted
    297     if self._input_graph_def:
--> 298       self._convert_graph_def()
    299     else:
    300       self._convert_saved_model()

~/python_env/tensorflow_gpu/lib/python3.6/site-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in _convert_graph_def(self)
    224     self._add_nodes_blacklist()
    225 
--> 226     self._run_conversion()
    227 
    228   def _collections_to_keep(self, collection_keys):

~/python_env/tensorflow_gpu/lib/python3.6/site-packages/tensorflow/python/compiler/tensorrt/trt_convert.py in _run_conversion(self)
    202         grappler_session_config,
    203         self._grappler_meta_graph_def,
--> 204         graph_id=b"tf_graph")
    205     self._converted = True
    206 

~/python_env/tensorflow_gpu/lib/python3.6/site-packages/tensorflow/python/grappler/tf_optimizer.py in OptimizeGraph(config_proto, metagraph, verbose, graph_id, cluster)
     39                                           config_proto.SerializeToString(),
     40                                           metagraph.SerializeToString(),
---> 41                                           verbose, graph_id)
     42   if ret_from_swig is None:
     43     return None

InvalidArgumentError: Failed to import metagraph, check error log for more info.

I also tried converting your model to UFF, but that didn't work either. Do you have any experience with TensorRT or TF-TRT? Can you help me with it? Thank you so much!

abhinavatai commented 3 years ago

Hey, this error sometimes occurs when you are using the wrong output nodes. Correct your output nodes, i.e. the names passed as "nodes_blacklist".
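
One way to check which names belong in "nodes_blacklist" is to list the nodes of the frozen graph that no other node consumes, since those are usually the outputs. The following is a minimal sketch, not a confirmed fix for this model: `candidate_output_nodes` and `load_frozen_graph_nodes` are hypothetical helper names, the demo node names are invented, and the loader assumes the same TF 1.x API used in the question. It may turn out that 'logits' and 'classes' are simply not present in this graph.

```python
from collections import namedtuple


def candidate_output_nodes(nodes):
    """Return names of nodes whose output is not consumed by any other node.

    `nodes` is any iterable of objects with .name, .op and .input,
    such as the NodeDef entries in GraphDef.node.
    """
    nodes = list(nodes)
    consumed = set()
    for node in nodes:
        for inp in node.input:
            # Inputs look like "name", "name:1" (output index)
            # or "^name" (control dependency).
            consumed.add(inp.lstrip("^").split(":")[0])
    skip_ops = ("Const", "Placeholder", "NoOp")
    return [n.name for n in nodes
            if n.name not in consumed and n.op not in skip_ops]


def load_frozen_graph_nodes(pb_path):
    """Load a frozen graph with the TF 1.x API, as in the question."""
    import tensorflow as tf  # imported lazily; only needed for real graphs
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def.node


# Stand-in for NodeDef entries; these node names are made up for illustration.
Node = namedtuple("Node", ["name", "op", "input"])
demo = [
    Node("input_1", "Placeholder", []),
    Node("conv/kernel", "Const", []),
    Node("conv/Conv2D", "Conv2D", ["input_1", "conv/kernel"]),
    Node("pred_pose/mul", "Mul", ["conv/Conv2D:0"]),
]
print(candidate_output_nodes(demo))
```

With the real model, calling `candidate_output_nodes(load_frozen_graph_nodes('./fsanet_capsule_3_16_2_21_5.pb'))` should print the candidate names, which can then be passed as `nodes_blacklist` and `return_elements` in place of the placeholder names.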