Inference results (FP16) for both the Llama and SDXL models in Torch-TensorRT's torch.compile backend show accuracy discrepancies relative to their Torch counterparts.
Specifically, the inference results from Llama and SDXL differ from the Torch inference results when given the same seed. This behavior is not reproduced on plain SD or on smaller transformer-based models.
One of the few complex layers that SDXL's UNet and Llama share is torch.ops.aten._scaled_dot_product_efficient_attention.default. We begin with this as a potential source of error.
After testing a multitude of configurations to isolate the issue, we observed that the following code block produced a large margin of error when the TRT output is compared against Torch and NumPy:
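(The exact repro block is not reproduced in this write-up. As a rough sketch only, a standalone module that exercises the same attention op and is compiled through the torch_tensorrt backend could look like the following; the module name, shapes, and option keys here are illustrative assumptions, not our exact configuration.)

```python
import torch
import torch_tensorrt  # registers the "torch_tensorrt" torch.compile backend

class AttentionModule(torch.nn.Module):
    def forward(self, q, k, v):
        # On CUDA this typically lowers to
        # torch.ops.aten._scaled_dot_product_efficient_attention.default
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

model = AttentionModule().eval().cuda()
# Illustrative FP16 shapes only; not the exact SDXL/Llama configuration
q = torch.randn(2, 10, 4096, 64, dtype=torch.half, device="cuda")
k = torch.randn(2, 10, 4096, 64, dtype=torch.half, device="cuda")
v = torch.randn(2, 10, 4096, 64, dtype=torch.half, device="cuda")

trt_model = torch.compile(model, backend="torch_tensorrt",
                          options={"enabled_precisions": {torch.half}})

torch_out = model(q, k, v)
trt_out = trt_model(q, k, v)
print("max |TRT - Torch|:", (trt_out - torch_out).abs().max().item())
```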
The repro above was then narrowed to the following simple matrix multiply, which, when run in FP16 with the dimensions (8192 x 640) and (640 x 640) used in our SDXL configuration, produces a maximum element-wise difference of about 10 between the TRT output and the Torch output. The mean difference was also high, at around 0.5.
```python
import torch

class TestModule(torch.nn.Module):
    def forward(self, q, k):
        return q @ k  # FP16 matmul: q is (8192 x 640), k is (640 x 640)
```
Additionally, the native_layer_norm operator may be contributing to the error, since excluding it from TensorRT conversion also improves accuracy. This is still under investigation.
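As a hedged sketch of how an operator can be left to run in PyTorch rather than being converted to TensorRT, the Torch-TensorRT compile options expose a torch_executed_ops setting; the tiny module below and the assumption that the option accepts aten op-name strings are for illustration only:

```python
import torch
import torch_tensorrt

# Illustrative module containing only a layer norm, which lowers to
# torch.ops.aten.native_layer_norm.default during tracing.
ln = torch.nn.LayerNorm(640).eval().cuda().half()

# Sketch: ask Torch-TensorRT to keep layer norm in PyTorch instead of
# converting it (torch_executed_ops is assumed to accept op-name strings).
trt_ln = torch.compile(
    ln,
    backend="torch_tensorrt",
    options={
        "enabled_precisions": {torch.half},
        "torch_executed_ops": {"torch.ops.aten.native_layer_norm.default"},
    },
)

x = torch.randn(8192, 640, dtype=torch.half, device="cuda")
print("max |TRT - Torch|:", (trt_ln(x) - ln(x)).abs().max().item())
```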
We have further narrowed the matmul cases to make the issue easier to reproduce, as sketched below.
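A minimal standalone comparison along these lines might look like the following sketch; the backend name and option keys are assumptions, and the NumPy reference is computed in FP64:

```python
import numpy as np
import torch
import torch_tensorrt

class TestModule(torch.nn.Module):
    def forward(self, q, k):
        return q @ k

# The (8192 x 640) x (640 x 640) FP16 case described above
q = torch.randn(8192, 640, dtype=torch.half, device="cuda")
k = torch.randn(640, 640, dtype=torch.half, device="cuda")

model = TestModule().eval().cuda()
trt_model = torch.compile(model, backend="torch_tensorrt",
                          options={"enabled_precisions": {torch.half}})

torch_out = model(q, k)
trt_out = trt_model(q, k)
np_out = q.double().cpu().numpy() @ k.double().cpu().numpy()  # FP64 reference

print("max |TRT - Torch|:", (trt_out - torch_out).abs().max().item())
print("max |TRT - NumPy|:",
      np.abs(trt_out.double().cpu().numpy() - np_out).max())
```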
Next Steps
Check whether the issues persist when using FP32 precision; if they do, continue narrowing down the cases to identify the layers responsible for the remaining accuracy discrepancies. A sketch of the FP32 rerun follows.
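As a minimal sketch (reusing the TestModule and the option names assumed above), the FP32 rerun would only change the input dtype and the enabled precision:

```python
# Sketch: same matmul repro, but with FP32 inputs and FP32-only TRT precision.
q32 = torch.randn(8192, 640, dtype=torch.float32, device="cuda")
k32 = torch.randn(640, 640, dtype=torch.float32, device="cuda")

trt_model_fp32 = torch.compile(
    TestModule().eval().cuda(),
    backend="torch_tensorrt",
    options={"enabled_precisions": {torch.float32}},
)
print("max |TRT - Torch|:",
      (trt_model_fp32(q32, k32) - q32 @ k32).abs().max().item())
```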