Closed lanluo-nvidia closed 4 weeks ago
Here's what I think could be a simpler way of doing this:
1) We probably don't have to store output_shapes in the TorchTensorRTModule class. Once compilation is finished, verify that the nodes of the TRT graph modules carry metadata; if they don't, we can restore it with node.meta["val"] = original metadata.
Reference: https://github.com/pytorch/TensorRT/blob/3eb48d786d403b12bd3700004c60e08c5c002f7b/py/torch_tensorrt/dynamo/_compiler.py#L496-L499
Here the node corresponding to _run_on_acc0 can be queried as
trt_module_node = next(node for node in gm.graph.nodes if node.name == "_run_on_acc0")
trt_module_node.meta["val"] should already hold the fake tensors that the exporter needs.
2) exporter: We have the TRT module node here: https://github.com/pytorch/TensorRT/blob/3eb48d786d403b12bd3700004c60e08c5c002f7b/py/torch_tensorrt/dynamo/_exporter.py#L364
We could directly set the following (after ensuring trt_module_node.meta["val"] always exists):
trt_node.meta["val"] = trt_module_node.meta["val"]
3) infer_module_types: We can replace the dummy inference with graph inspection by reading the output metadata. This function could return a list of FakeTensors, from which we extract the dtypes to pass to TRTInterpreter.
Replacing the dummy inference will also require changes to our converter test suite.
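A minimal sketch of points 1) and 2), to make the lookup-and-copy step concrete. The helpers find_node and copy_val_metadata and the node name trt_engine_call are hypothetical, not part of torch_tensorrt. They only touch the .nodes, .name and .meta attributes that torch.fx graphs and nodes expose, and the demo uses SimpleNamespace stand-ins rather than a real GraphModule so that it runs without PyTorch installed.

```python
from types import SimpleNamespace


def find_node(graph, name):
    """Return the single node called `name`; fail loudly if it is missing or duplicated."""
    matches = [node for node in graph.nodes if node.name == name]
    if len(matches) != 1:
        raise LookupError(
            f"expected exactly one node named {name!r}, found {len(matches)}"
        )
    return matches[0]


def copy_val_metadata(src_node, dst_node):
    """Copy meta['val'] from src_node to dst_node, refusing to continue if it is absent."""
    if "val" not in src_node.meta:
        raise KeyError(f"node {src_node.name!r} has no meta['val'] to copy")
    dst_node.meta["val"] = src_node.meta["val"]


# --- Stand-ins for torch.fx objects (assumption: illustration only) ---
def make_node(name, op="call_module", meta=None):
    return SimpleNamespace(name=name, op=op, meta=dict(meta or {}))


# Plays the role of the FakeTensor stored on the compiled submodule's node.
fake_output = SimpleNamespace(shape=(1, 10), dtype="float32")

graph = SimpleNamespace(
    nodes=[
        make_node("x", op="placeholder"),
        make_node("_run_on_acc0", meta={"val": fake_output}),
    ]
)

trt_module_node = find_node(graph, "_run_on_acc0")
trt_node = make_node("trt_engine_call")  # hypothetical node created by the exporter
copy_val_metadata(trt_module_node, trt_node)
```

Raising when the node or its metadata is missing, instead of silently skipping the copy, surfaces the "ensure meta["val"] always exists" precondition at the point where it is violated.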
Description
There are two changes introduced in this PR:
1) During the compile stage: skip the dummy inference and use graph inspection to read output_node.meta['val'] instead.
2) During the save stage: skip run_shape_analysis and use graph inspection to read output_node.meta['val'] instead.
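As a rough illustration of what "graph inspection" means in both stages, the sketch below reads dtypes and shapes from the metadata of the nodes feeding the graph's output node, without running the module. The function names output_vals, output_dtypes and output_shapes are hypothetical and are not the PR's actual implementation. The code relies only on the .op, .args and .meta attributes found on torch.fx nodes, and the demo uses SimpleNamespace stand-ins (with strings in place of real torch dtypes) so that it runs without PyTorch.

```python
from types import SimpleNamespace


def _flatten(value):
    """Yield the leaves of arbitrarily nested tuples/lists."""
    if isinstance(value, (tuple, list)):
        for item in value:
            yield from _flatten(item)
    else:
        yield value


def output_vals(graph):
    """Collect meta['val'] entries for every node that feeds the output node."""
    output_nodes = [node for node in graph.nodes if node.op == "output"]
    if len(output_nodes) != 1:
        raise ValueError(f"expected one output node, found {len(output_nodes)}")
    vals = []
    for arg in _flatten(output_nodes[0].args):
        meta = getattr(arg, "meta", None)
        if meta is None or "val" not in meta:
            label = getattr(arg, "name", repr(arg))
            raise KeyError(f"output argument {label} has no meta['val']")
        # A single node may itself carry a tuple of values.
        vals.extend(_flatten(meta["val"]))
    return vals


def output_dtypes(graph):
    return [val.dtype for val in output_vals(graph)]


def output_shapes(graph):
    return [tuple(val.shape) for val in output_vals(graph)]


# --- Stand-ins for torch.fx objects (assumption: illustration only) ---
def make_node(name, op, args=(), val=None):
    meta = {} if val is None else {"val": val}
    return SimpleNamespace(name=name, op=op, args=args, meta=meta)


a = make_node("add", "call_function", val=SimpleNamespace(shape=(2, 3), dtype="float32"))
b = make_node("argmax", "call_function", val=SimpleNamespace(shape=(2,), dtype="int64"))
out = make_node("output", "output", args=((a, b),))
graph = SimpleNamespace(nodes=[a, b, out])

dtypes = output_dtypes(graph)
shapes = output_shapes(graph)
```

This sketch does not handle dict-valued outputs or non-tensor constants; a missing meta['val'] raises instead of falling back to a dummy forward pass, which is a design choice of the sketch rather than something the PR states.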
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: