tensorflow / models

Models and examples built with TensorFlow

Converting TF Model(frozen_graph) to TensorRT Engine #5300

Closed Akhtar303nu closed 4 years ago

Akhtar303nu commented 6 years ago

```python
import keras
import keras.backend as K
import tensorflow as tf
import uff
import numpy as np
import tensorrt as trt
from tensorrt.parsers import uffparser
import pycuda.driver as cuda
import pycuda.autoinit

output_names = ['predictions/Softmax']
frozen_graph_filename = 'frozen_inference_graph.pb'
sess = K.get_session()

# freeze graph and remove training nodes
graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_names)
graph_def = tf.graph_util.remove_training_nodes(graph_def)

# write frozen graph to file
with open(frozen_graph_filename, 'wb') as f:
    f.write(graph_def.SerializeToString())

# convert frozen graph to uff
uff_model = uff.from_tensorflow_frozen_model(frozen_graph_filename, output_names)
G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.ERROR)
parser = uffparser.create_uff_parser()
parser.register_input("Placeholder", (1, 28, 28), 0)
parser.register_output("fc2/Relu")
engine = trt.utils.uff_to_trt_engine(G_LOGGER, uff_model, parser, 1, 1 << 20)
parser.destroy()
runtime = trt.infer.create_infer_runtime(G_LOGGER)
context = engine.create_execution_context()
output = np.empty(10, dtype=np.float32)

# allocate device memory (img and label are assumed to be defined elsewhere)
d_input = cuda.mem_alloc(1 * img.nbytes)
d_output = cuda.mem_alloc(1 * output.nbytes)

bindings = [int(d_input), int(d_output)]
stream = cuda.Stream()

# transfer input data to device
cuda.memcpy_htod_async(d_input, img, stream)

# execute model
context.enqueue(1, bindings, stream.handle, None)

# transfer predictions back
cuda.memcpy_dtoh_async(output, d_output, stream)

# synchronize threads
stream.synchronize()
print("Test Case: " + str(label))
print("Prediction: " + str(np.argmax(output)))
trt.utils.write_engine_to_file("./tf_mnist.engine", engine.serialize())
```

List of packages:

- Linux: 16
- CUDA: 9.0
- TensorRT: 4
- Python: 3.5
- GPU: GTX 1080
- TensorFlow: 1.10.1

Error

```
Traceback (most recent call last):
  File "bbb.py", line 11, in <module>
    graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_names)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/graph_util_impl.py", line 232, in convert_variables_to_constants
    inference_graph = extract_sub_graph(input_graph_def, output_node_names)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/graph_util_impl.py", line 174, in extract_sub_graph
    _assert_nodes_are_present(name_to_node, dest_nodes)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/graph_util_impl.py", line 133, in _assert_nodes_are_present
    assert d in name_to_node, "%s is not in graph" % d
AssertionError: predictions/Softmax is not in graph
```

Alternatively, please suggest any code that converts a TensorFlow frozen graph (frozen_inference_graph.pb) to a TRT engine for an object detection task.
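One possible alternative to the UFF route for an Object Detection API model is TF-TRT, which rewrites compatible subgraphs of the frozen GraphDef into TensorRT engine ops. This is only a sketch under assumptions (a TensorFlow 1.x build with TensorRT support, the detection API's standard output node names, a hypothetical file path); the `mib_to_bytes` helper is plain Python:

```python
# Hedged sketch: optimize an already-frozen detection graph with TF-TRT
# (TensorFlow 1.x contrib API). Paths and node names are assumptions taken
# from the Object Detection API defaults, not verified here.
import os

OUTPUT_NODES = ["num_detections", "detection_boxes",
                "detection_scores", "detection_classes"]

def mib_to_bytes(mib):
    """TensorRT workspace sizes are given in bytes; convert from MiB."""
    return mib << 20

FROZEN_GRAPH = "frozen_inference_graph.pb"  # hypothetical path
if os.path.exists(FROZEN_GRAPH):  # guarded so the sketch is safe to run anywhere
    import tensorflow as tf
    import tensorflow.contrib.tensorrt as trt  # needs a TRT-enabled TF build

    graph_def = tf.GraphDef()
    with tf.gfile.GFile(FROZEN_GRAPH, "rb") as f:
        graph_def.ParseFromString(f.read())

    # Rewrites TensorRT-compatible subgraphs into TRTEngineOp nodes.
    trt_graph = trt.create_inference_graph(
        input_graph_def=graph_def,
        outputs=OUTPUT_NODES,
        max_batch_size=1,
        max_workspace_size_bytes=mib_to_bytes(32),
        precision_mode="FP16")

    with tf.gfile.GFile("trt_inference_graph.pb", "wb") as f:
        f.write(trt_graph.SerializeToString())
```

Unsupported ops stay in TensorFlow, so the result is a hybrid graph rather than a standalone TRT engine, which sidesteps the UFF parser's coverage limits.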

tensorflowbutler commented 6 years ago

Thank you for your post. We noticed you have not filled out the following fields in the issue template. Could you update them if they are relevant in your case, or leave them as N/A? Thanks.

- What is the top-level directory of the model you are using
- Have I written custom code
- OS Platform and Distribution
- TensorFlow installed from
- TensorFlow version
- Bazel version
- CUDA/cuDNN version
- GPU model and memory
- Exact command to reproduce

qlzh727 commented 6 years ago

Which model are you trying to convert? Please fill in the information requested in the comment from @tensorflowbutler.

Akhtar303nu commented 6 years ago

Thank you for your response

I am using the Google pretrained object detection model (ssd_mobilenet_v1_coco). I want to convert frozen_inference_graph.pb to a TRT graph. Packages and other information:

- Linux: 16
- TensorFlow installation: pip (pip install tensorflow-gpu)
- Python: 3.5
- CUDA: 9.0, V9.0.176
- GPU: GTX 1080
- Memory: 8118 MiB

Can you share any code which performs this task (frozen graph to trt_engine)?


qlzh727 commented 6 years ago

Thanks for the reply. Assigning to pkulzc, who works on object detection.

Akhtar303nu commented 6 years ago

Thanks, I am waiting for your response.


pkulzc commented 6 years ago

Hmm, I don't really know the TRT engine, but it looks like it needs the output node name. In this case 'predictions/Softmax' is not a valid node name in our frozen graph.

Our output nodes can be found here, and StackOverflow might be a better place for trt_engine questions.
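Worth noting for readers hitting the same AssertionError: the frozen_inference_graph.pb shipped with the detection models is already frozen, so re-running convert_variables_to_constants (the call that raises the error) is unnecessary. A hedged sketch of fetching the documented detection outputs directly, where the `output_tensor_names` helper just appends the `:0` suffix that `session.run` expects and the image shape is a placeholder:

```python
# Hedged sketch: run inference on an already-frozen detection graph,
# skipping the freezing step entirely. Assumes TensorFlow 1.x and the
# Object Detection API's documented input/output node names.
import os

OUTPUT_NODES = ["num_detections", "detection_boxes",
                "detection_scores", "detection_classes"]

def output_tensor_names(node_names):
    """Map graph node names to the ':0' tensor names session.run expects."""
    return [name + ":0" for name in node_names]

FROZEN_GRAPH = "frozen_inference_graph.pb"  # hypothetical path
if os.path.exists(FROZEN_GRAPH):  # guarded so the sketch is safe to run anywhere
    import numpy as np
    import tensorflow as tf

    graph_def = tf.GraphDef()
    with tf.gfile.GFile(FROZEN_GRAPH, "rb") as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        with tf.Session(graph=graph) as sess:
            image = np.zeros((1, 300, 300, 3), dtype=np.uint8)  # dummy input
            outputs = sess.run(output_tensor_names(OUTPUT_NODES),
                               feed_dict={"image_tensor:0": image})
            print("detections:", int(outputs[0][0]))
```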

Akhtar303nu commented 6 years ago

Thanks pkulzc. Yes, you are right, these nodes are also present in my graph: 'num_detections', 'detection_boxes', 'detection_scores', 'detection_classes', 'detection_masks'. But I tried all of these node names and none worked for me; they generate the same error:

```
AssertionError                            Traceback (most recent call last)
<ipython-input> in <module>()
      9
     10 # freeze graph and remove training nodes
---> 11 graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_names)
     12 graph_def = tf.graph_util.remove_training_nodes(graph_def)
     13

/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/graph_util_impl.py in convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist, variable_names_blacklist)
    230   # This graph only includes the nodes needed to evaluate the output nodes, and
    231   # removes unneeded nodes like those involved in saving and assignment.
--> 232   inference_graph = extract_sub_graph(input_graph_def, output_node_names)
    233
    234   found_variables = {}

/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/graph_util_impl.py in extract_sub_graph(graph_def, dest_nodes)
    172   name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
    173       graph_def)
--> 174   _assert_nodes_are_present(name_to_node, dest_nodes)
    175
    176   nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/graph_util_impl.py in _assert_nodes_are_present(name_to_node, nodes)
    131   """Assert that nodes are present in the graph."""
    132   for d in nodes:
--> 133     assert d in name_to_node, "%s is not in graph" % d
    134
    135

AssertionError: detection_classes is not in graph
```

You can see my frozen graph: https://drive.google.com/file/d/1AGBpXzXVCABoLNnkOap-Q13EaLHLAeZ3/view?usp=sharing
And the file from which I created the frozen graph: https://drive.google.com/file/d/16TpAB9XYn3PEufVlwRBAR2vjfgzBTAZf/view?usp=sharing
Thanks

aejaex commented 5 years ago

Having the same issue, please resolve.

lengerke commented 5 years ago

For anyone ending up here from a search engine: make sure the output name you are handing to the function really is a correct node name within the TF model. You can check by running

```python
tf_node_list = [n.name for n in tf.get_default_graph().as_graph_def().node]
```

and looking for your node name in that complete list of nodes within your model.

Possibly the name is slightly different from what you expected.
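This check can also be done without restoring a session, by parsing the frozen .pb file itself. In the sketch below, `find_nodes` is plain Python, while the file-loading part assumes a TensorFlow 1.x GraphDef at a hypothetical path:

```python
# Hedged sketch: list node names straight from a frozen .pb file and search
# them for a candidate output name. The path is a placeholder.
import os

def find_nodes(node_names, substring):
    """Return every node name that contains the given substring."""
    return [name for name in node_names if substring in name]

FROZEN_GRAPH = "frozen_inference_graph.pb"  # hypothetical path
if os.path.exists(FROZEN_GRAPH):  # guarded so the sketch runs without the file
    import tensorflow as tf
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(FROZEN_GRAPH, "rb") as f:
        graph_def.ParseFromString(f.read())
    names = [n.name for n in graph_def.node]
    print(find_nodes(names, "detection"))
```

A substring search like this also catches near-misses, e.g. a name wrapped in an extra scope prefix that you did not expect.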

JRMeyer commented 5 years ago

@lengerke -- I'm trying similar code, but I can't even import the uff_to_trt_engine function from tensorrt.utils, and I can't find any documentation for the function... any pointers?

tensorflowbutler commented 4 years ago

Hi there, we are checking to see if you still need help on this, as this seems to be a considerably old issue. Please update this issue with the latest information, a code snippet to reproduce your issue, and the error you are seeing. If we don't hear from you in the next 7 days, this issue will be closed automatically. If you don't need help on this issue any more, please consider closing it.