pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] view_to_reshape metadata mismatch #3221

Open sean-xiang-applovin opened 4 weeks ago

sean-xiang-applovin commented 4 weeks ago

Bug Description

When the view nodes are replaced with reshape nodes, the metadata of the original view nodes is assigned to the reshape nodes in the wrong order.

For example, assume we have two view nodes, view_1 and view_2. From this code, we get the list [metadata1, metadata2].

I haven't dug deep yet, but after torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement) runs, the view nodes in the new graph are replaced with reshape nodes, and those reshape nodes can appear in a different order, e.g. reshape_default_2 before reshape_default_1. In that case, when the metadata is set back with set_metadata(gm, replacement_op, metadata), it can end up on the wrong nodes.
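A minimal pure-Python sketch of the failure mode described above, assuming the metadata is re-applied positionally (i-th saved entry to i-th new node). `Node`, `collect_metadata`, `set_metadata_positional`, `set_metadata_keyed`, the `"shape"` meta key, and the `replaces` back-pointer are all hypothetical stand-ins, not Torch-TensorRT or torch.fx APIs. It contrasts the positional assignment with a keyed one that looks up the meta of the exact node each replacement came from.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    """Hypothetical stand-in for a torch.fx.Node: name, op target, meta dict."""
    name: str
    target: str
    meta: dict = field(default_factory=dict)
    # Hypothetical back-pointer to the original node this one replaced.
    replaces: Optional[str] = None


def collect_metadata(nodes, target):
    # Save the meta of every node with `target`, in list (graph) order.
    return [n.meta for n in nodes if n.target == target]


def set_metadata_positional(nodes, target, metadata):
    # Positional strategy: the i-th new node gets the i-th saved meta.
    # This silently assumes the rewrite preserved node order.
    new_nodes = [n for n in nodes if n.target == target]
    for node, meta in zip(new_nodes, metadata):
        node.meta = meta


def set_metadata_keyed(nodes, target, meta_by_name):
    # Keyed strategy: each new node gets the meta of the node it replaced.
    for n in nodes:
        if n.target == target:
            n.meta = meta_by_name[n.replaces]


original = [
    Node("view_1", "aten.view", {"shape": (2, 8)}),
    Node("view_2", "aten.view", {"shape": (4, 4)}),
]
saved_list = collect_metadata(original, "aten.view")
saved_by_name = {n.name: n.meta for n in original if n.target == "aten.view"}


def rewritten_graph():
    # Simulated rewrite output: the replacement for view_2 comes first.
    return [
        Node("reshape_default_2", "aten.reshape", replaces="view_2"),
        Node("reshape_default_1", "aten.reshape", replaces="view_1"),
    ]


buggy = rewritten_graph()
set_metadata_positional(buggy, "aten.reshape", saved_list)
fixed = rewritten_graph()
set_metadata_keyed(fixed, "aten.reshape", saved_by_name)

for label, graph in (("positional", buggy), ("keyed", fixed)):
    for n in graph:
        print(f"{label}: {n.name} (from {n.replaces}) -> shape {n.meta['shape']}")
# positional: reshape_default_2 (from view_2) -> shape (2, 8)   <- wrong
# positional: reshape_default_1 (from view_1) -> shape (4, 4)   <- wrong
# keyed: reshape_default_2 (from view_2) -> shape (4, 4)
# keyed: reshape_default_1 (from view_1) -> shape (2, 8)
```

The keyed version only works if the link from each reshape node back to its original view node is available. One possible (untested here) source for that link may be the `ReplacedPatterns` records, with `anchor`, `nodes_map`, and `replacements` fields, that `torch.fx.subgraph_rewriter.replace_pattern_with_filters` returns in recent PyTorch versions.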

To Reproduce

Sorry, I cannot share the model's graph; I will try my best to put together a toy model that reproduces the error.



sean-xiang-applovin commented 3 weeks ago

I checked the source code, and it seems we already have converter support for torch.ops.aten.view, so the best way to solve this is probably to delete the view_to_reshape pass.