pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] [weight-stripped engine] doesn't work for `TorchTensorRTModule` #3217

Closed: zewenli98 closed this issue 1 month ago

zewenli98 commented 1 month ago

Bug Description

PR https://github.com/pytorch/TensorRT/pull/3167 adds support for weight-stripped engines. It works with `PythonTorchTensorRTModule` but not with `TorchTensorRTModule`.

I observed the issue in the test: https://github.com/pytorch/TensorRT/blob/76bdf5e0f0e5e0e31d5bc4cbf1bedfa5f4f4ea32/tests/py/dynamo/models/test_weight_stripped_engine.py#L487-L523

The CI test reports the error:

FAILED models/test_weight_stripped_engine.py::TestWeightStrippedEngine::test_two_TRTRuntime_in_refitting - AssertionError: False is not true : TorchTensorRTModule outputs don't match with the original model. Cosine sim score: 0.0 Threshold: 0.99

I printed `refitted_output` when using `TorchTensorRTModule`, and it is all zeros, so the refit appears to have failed.
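For context, an all-zero output is consistent with the reported score of 0.0. The snippet below is a minimal sketch of a cosine-similarity check in plain Python. It is not the helper used by the test suite; the function name, the `eps` guard, and the sample values are assumptions made for illustration.

```python
import math


def cosine_similarity(a, b, eps=1e-12):
    """Cosine similarity of two equal-length sequences (illustrative helper)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # Assumption: a zero-norm vector is scored as 0.0 instead of dividing by zero.
    if norm_a < eps or norm_b < eps:
        return 0.0
    return dot / (norm_a * norm_b)


reference = [0.5, -1.2, 3.0]   # stand-in for the original model's output
all_zeros = [0.0, 0.0, 0.0]    # stand-in for refitted_output

score_zero = cosine_similarity(reference, all_zeros)
score_same = cosine_similarity(reference, reference)
```

With these inputs, `score_zero` falls below the 0.99 threshold quoted in the error, while `score_same` stays at roughly 1.0.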

zewenli98 commented 1 month ago

This happens because the `EXCLUDE_WEIGHTS` flag is not cleared when the engine is serialized during refitting. It will be fixed in #3167.
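The effect of a stale flag can be illustrated with a toy model. This is only a sketch under stated assumptions: `ToyEngine`, its methods, and the integer flag value are hypothetical and do not reproduce the TensorRT or Torch-TensorRT API. Zeros stand in for the missing weights of a stripped engine.

```python
# Hypothetical bit flag, named after the flag mentioned above.
EXCLUDE_WEIGHTS = 1 << 0


class ToyEngine:
    """Toy stand-in for an engine that can be refitted and serialized."""

    def __init__(self, num_weights):
        # A weight-stripped engine starts without real weights.
        self.weights = [0.0] * num_weights

    def refit(self, weights):
        # Refitting loads the real weights into the engine.
        self.weights = list(weights)

    def serialize(self, flags):
        # With EXCLUDE_WEIGHTS set, the weights are dropped from the output.
        if flags & EXCLUDE_WEIGHTS:
            return {"weights": [0.0] * len(self.weights)}
        return {"weights": list(self.weights)}


engine = ToyEngine(num_weights=2)
flags = EXCLUDE_WEIGHTS            # flag left over from building the stripped engine
engine.refit([1.5, -2.0])

stale = engine.serialize(flags)    # flag not cleared: refitted weights are lost
flags &= ~EXCLUDE_WEIGHTS          # clear the flag before serializing
fixed = engine.serialize(flags)    # refitted weights are kept
```

In this toy model, `stale` still carries zeros after the refit, which mirrors the all-zero output described in the report, while `fixed` carries the refitted weights.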