jeng1220 / KerasToTensorRT


Keras to TensorRT Examples

This is a simple demonstration of running a Keras model on TensorFlow with TensorRT integration (TFTRT), or on TensorRT directly, without invoking "freeze_graph.py".
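For context, freeze_graph.py converts the variables of a TensorFlow graph into constants so that the graph is self-contained and can be handed to an optimizer such as TensorRT. In TensorFlow 1.x the same conversion can also be done in-process with `tf.graph_util.convert_variables_to_constants`; this README does not say which mechanism the examples use. The sketch below is a toy model of the idea only, assuming a hypothetical dict-based graph format (it is not TensorFlow's GraphDef).

```python
from typing import Any, Dict

# Hypothetical toy graph format (NOT TensorFlow's GraphDef):
# node name -> {"op": str, "inputs": [node names], optional "value": ...}
Graph = Dict[str, Dict[str, Any]]


def freeze(graph: Graph, variable_values: Dict[str, Any]) -> Graph:
    """Return a new graph in which every "Variable" node is replaced by a
    "Const" node carrying that variable's current value."""
    frozen: Graph = {}
    for name, node in graph.items():
        if node["op"] == "Variable":
            # Fold the trained value into the graph itself.
            # A variable with no supplied value raises KeyError.
            frozen[name] = {"op": "Const", "value": variable_values[name], "inputs": []}
        else:
            # Other ops are copied; the inputs list is copied too,
            # so the original graph is left untouched.
            frozen[name] = {**node, "inputs": list(node["inputs"])}
    return frozen
```

For example, freezing a graph `x -> MatMul <- w` with `freeze(graph, {"w": [[1.0], [2.0]]})` turns `w` into a `Const` node, while `x` stays a placeholder that is fed at inference time.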

Note: We recommend using the NVIDIA TensorFlow Docker image to run these examples. You can download the images from NVIDIA NGC.

Requirement

If you want to run the model on TensorRT directly, PyCUDA is also needed.

Examples

tftrt_example.py demonstrates how to run a Keras model on TFTRT. This approach supports both the NCHW and NHWC formats because TensorFlow handles the data-format differences itself.
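NCHW and NHWC hold the same values in different memory orders. As a minimal illustration, assuming dense row-major storage, the flat offset of element (n, c, h, w) in each layout can be computed as below; the function names are illustrative and do not come from the repository.

```python
def offset_nchw(n: int, c: int, h: int, w: int, C: int, H: int, W: int) -> int:
    """Flat offset in NCHW: each channel's H*W plane is stored contiguously."""
    return ((n * C + c) * H + h) * W + w


def offset_nhwc(n: int, c: int, h: int, w: int, C: int, H: int, W: int) -> int:
    """Flat offset in NHWC: the C channel values of one pixel are adjacent."""
    return ((n * H + h) * W + w) * C + c
```

With C=3, H=2, W=2, the element (n=0, c=1, h=0, w=1) sits at offset 5 in NCHW but at offset 4 in NHWC, which is why a model built for one layout cannot consume a buffer laid out in the other without a conversion.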

$ python tftrt_example.py

tftrt_resnet_example.py demonstrates how to run the Keras Applications ResNet50 model on TFTRT.

$ python tftrt_resnet_example.py

tftrt_multi_inputs_mutli_outputs_example.py demonstrates how to run a multi-input/output Keras model on TFTRT.

$ python tftrt_multi_inputs_mutli_outputs_example.py

trt_example.py demonstrates how to run a Keras model on TensorRT directly, which can achieve the fastest inference speed of these approaches. Because TensorRT did not yet fully support NHWC when this was written, this approach only works with the NCHW format.
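If your input data is in NHWC order, it has to be reordered to NCHW before being copied into the engine's input buffer. With NumPy this is usually `np.ascontiguousarray(x.transpose(0, 3, 1, 2))`. The dependency-free sketch below shows the same reordering on a flat buffer, assuming dense row-major storage; the helper name is hypothetical and not part of trt_example.py.

```python
from typing import List, Sequence


def nhwc_to_nchw(buf: Sequence[float], n: int, h: int, w: int, c: int) -> List[float]:
    """Reorder a flat NHWC buffer into a new flat NCHW buffer."""
    if len(buf) != n * h * w * c:
        raise ValueError("buffer size does not match the given shape")
    out = [0.0] * len(buf)
    for ni in range(n):
        for hi in range(h):
            for wi in range(w):
                for ci in range(c):
                    src = ((ni * h + hi) * w + wi) * c + ci  # NHWC offset
                    dst = ((ni * c + ci) * h + hi) * w + wi  # NCHW offset
                    out[dst] = buf[src]
    return out
```

For one 1x2 RGB image, `nhwc_to_nchw([1, 2, 3, 4, 5, 6], n=1, h=1, w=2, c=3)` groups the values per channel as `[1, 4, 2, 5, 3, 6]`.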

$ python trt_example.py

Appendix

get_mnist_model.py generates the Keras models these examples need in two input formats: one for NCHW and one for NHWC.

Note: the required models are already provided in the repo.

$ python get_mnist_model.py -h # shows help message
$ python get_mnist_model.py -f 0 # generates model for NCHW format
$ python get_mnist_model.py -f 1 # generates model for NHWC format
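A command-line switch like `-f` can be wired to Keras' `data_format` setting, where `"channels_first"` corresponds to NCHW and `"channels_last"` to NHWC. The sketch below is hypothetical. Only the `-f 0` (NCHW) / `-f 1` (NHWC) semantics come from this README. The `FORMATS` table, the `model_config` helper, and the 28x28 single-channel MNIST input shapes are assumptions, not the script's actual code.

```python
import argparse
from typing import List, Optional, Tuple

# Hypothetical mapping: README flag value -> (Keras data_format, MNIST input shape).
FORMATS = {
    0: ("channels_first", (1, 28, 28)),  # NCHW: channel axis before height/width
    1: ("channels_last", (28, 28, 1)),   # NHWC: channel axis after height/width
}


def model_config(argv: Optional[List[str]] = None) -> Tuple[str, Tuple[int, int, int]]:
    """Parse -f and return the data_format and input shape a model would use."""
    parser = argparse.ArgumentParser(description="Generate an MNIST Keras model (sketch).")
    parser.add_argument("-f", dest="format", type=int, choices=sorted(FORMATS),
                        required=True, help="0 = NCHW (channels_first), 1 = NHWC (channels_last)")
    args = parser.parse_args(argv)
    return FORMATS[args.format]
```

`argparse` adds `-h` automatically, so `model_config(["-h"])` prints the help message and exits, matching the first command above.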

Reference