pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

✨[Feature] A hook function after each graph change or partition #3233

Open sean-xiang-applovin opened 3 days ago

sean-xiang-applovin commented 3 days ago

Is your feature request related to a problem? Please describe.

I have been trying to compile our model with TensorRT from an exported program. The model is not very big; the original graph contains roughly 2k nodes.

I am running into various problems, some of which occur even before the partitioned graph is converted, for example during decomposition or the post-lowering phase. Sometimes the debug log reports an error on a node that I cannot find in my original graph.

I find it helpful to dump the graph after each "change". Logs are printed, but:

  1. not all passes print logs
  2. it is hard to scroll back through the log history and copy the whole graph for each pass

Describe the solution you'd like

A hook function that is called after each intermediate graph is produced, for example the graph after each pass and each partitioned graph.

The function takes a gm object and some metadata about it, such as a name or a number k indicating that it is the kth change to the original graph.

The most common implementation of this hook would probably print the graph or save it to disk, so that I can check each graph and find the first one that contains the error node named in the error log.

I can also see this being helpful for people who want to visualize each graph and each graph change during compilation.
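To make the request concrete, here is a minimal sketch of what such a hook could look like. It is not an existing Torch-TensorRT API: the names `dump_graph_hook` and `first_dump_containing`, the `(gm, name, index)` signature, and the `graph_dumps` directory are all hypothetical. The only assumption about `gm` is that `str(gm.graph)` yields a readable dump, which holds for a `torch.fx.GraphModule`.

```python
from pathlib import Path
from typing import Any, Optional


def dump_graph_hook(gm: Any, name: str, index: int, out_dir: str = "graph_dumps") -> Path:
    """Hypothetical hook: write the textual form of gm.graph to <out_dir>/<index>_<name>.txt."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    # Zero-pad the index so the files sort in the order the changes happened.
    path = out / f"{index:03d}_{name}.txt"
    path.write_text(str(gm.graph))
    return path


def first_dump_containing(node_name: str, out_dir: str = "graph_dumps") -> Optional[Path]:
    """Return the earliest dump whose text mentions node_name, or None if no dump does."""
    for path in sorted(Path(out_dir).glob("*.txt")):
        if node_name in path.read_text():
            return path
    return None
```

With dumps written this way, calling `first_dump_containing` with the node name from the error log would point at the first graph change that introduced that node.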

Describe alternatives you've considered

Additional context

narendasan commented 2 days ago

@peri044 did we implement LOG_GRAPH for dynamo?

narendasan commented 1 day ago

@sean-xiang-applovin We use a PassManager to manage the lowering. https://github.com/pytorch/TensorRT/blob/3110d31db850016736de4dfaaec9004bcfbf4c70/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py#L7

There are two sets of passes: ATEN_PRE_LOWERING_PASSES and ATEN_POST_LOWERING_PASSES

You can inject custom passes that do whatever you need (print graphs, dump or save graphs, etc.) and place them at a particular index in the pass pipeline. It's not a particularly ergonomic API, but if you want to submit improvements you find useful in a PR, we would be interested.
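The idea of placing dump passes at chosen positions in a pipeline can be sketched roughly as follows. This is an illustration only and does not use the Torch-TensorRT pass manager: the pipeline is a plain Python list, the pass signature `(gm, settings) -> gm` is an assumption, and `make_dump_pass` and `with_dumps` are hypothetical helper names. How the resulting passes would be registered with the real pass manager is not shown here.

```python
from typing import Any, Callable, List

# Assumed pass signature: (graph_module, settings) -> graph_module.
LoweringPass = Callable[[Any, Any], Any]
# Assumed hook signature: (graph_module, label, index) -> None.
Hook = Callable[[Any, str, int], None]


def make_dump_pass(hook: Hook, label: str, index: int) -> LoweringPass:
    """Build a pass that calls the hook and returns the graph module unchanged."""
    def dump_pass(gm: Any, settings: Any) -> Any:
        hook(gm, label, index)
        return gm
    return dump_pass


def with_dumps(passes: List[LoweringPass], hook: Hook) -> List[LoweringPass]:
    """Return a new pass list with a dump pass inserted after every original pass."""
    result: List[LoweringPass] = []
    for i, lowering_pass in enumerate(passes):
        result.append(lowering_pass)
        # Label each dump with the name of the pass that just ran.
        label = getattr(lowering_pass, "__name__", f"pass_{i}")
        result.append(make_dump_pass(hook, label, i))
    return result
```

Because each dump pass returns the graph module untouched, interleaving them should not change the result of the pipeline; it only gives the hook a view of the graph after every step.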

sean-xiang-applovin commented 15 hours ago

@narendasan Sure, I think the pass manager could help with the graph before partitioning. What if I want to access the partitioned graph?

narendasan commented 14 hours ago

We do have the dry_run system, which can provide insight into the post-partitioning structure, but I think we would be unlikely to add a callback or similar mechanism to save partitioned but not yet compiled graphs.

sean-xiang-applovin commented 12 hours ago

Got it, thanks for your reply @narendasan