pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] Const indices failed with embedding bag #3263

Open sean-xiang-applovin opened 4 weeks ago

sean-xiang-applovin commented 4 weeks ago

Bug Description

The indices are a constant tensor, so they get constant-folded into a frozen param. The meta of the frozen param node is an empty dict, so the converter validation check fails. As a result, torch.ops.aten._embedding_bag.default is treated as an unsupported op and compilation fails.

To Reproduce

import torch
import torch_tensorrt
import tensorrt
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
print(torch.__version__) # 2.5.0+cu124
print(torch_tensorrt.__version__) # 2.5.0
print(tensorrt.__version__) # 10.3.0

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.embedding_bag_module = torch.nn.EmbeddingBag(100, 32, mode='sum')
        # Constant indices 0..99, stored as a buffer so they are part of the module state
        self.register_buffer("index_tensor", torch.arange(100, dtype=torch.long))

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        out = self.embedding_bag_module(self.index_tensor.broadcast_to(tensor.shape), per_sample_weights=tensor)
        return out

error_model_input = (torch.randn(20, 100, dtype=torch.float32), )
error_model = ToyModel()
error_model_eval = error_model.eval()
with torch.no_grad():
    ep = torch.export.export(error_model_eval, args=error_model_input)
    compiled = dynamo_compile(
        exported_program=ep,
        disable_tf32=True,
        inputs=error_model_input,
        min_block_size=1,
        debug=True,
    )

zewenli98 commented 3 weeks ago

Due to some limitations, we need the node meta to deal with data-dependent issues. In addition, we currently only support 1D indices/input. If this doesn't work for you, I think you will have to fall back to PyTorch for this op for now.
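
A minimal sketch of the two options mentioned above, using the shapes from the repro. The first part is plain PyTorch and shows how the 2D indices can be rewritten as 1D indices plus offsets with the same result. Whether that form passes the converter validation is not confirmed here, since the empty meta on the frozen param may still apply. The second part is a hypothetical helper, `compile_with_fallback`, that passes `torch_executed_ops` to `torch_tensorrt.dynamo.compile` so the op runs in PyTorch. The string form of the op name is an assumption and may differ between versions.

```python
import torch

torch.manual_seed(0)

# Same shapes as the repro: 20 bags of 100 indices each.
bag = torch.nn.EmbeddingBag(100, 32, mode="sum").eval()
weights = torch.randn(20, 100)
indices_2d = torch.arange(100).broadcast_to(weights.shape)

# 2D form, as in the repro: each row is one bag.
out_2d = bag(indices_2d, per_sample_weights=weights)

# Equivalent 1D form: flatten indices and weights, and mark where each bag starts.
num_bags, bag_size = weights.shape
offsets = torch.arange(0, num_bags * bag_size, bag_size)
out_1d = bag(
    indices_2d.reshape(-1),
    offsets,
    per_sample_weights=weights.reshape(-1),
)

same = torch.allclose(out_2d, out_1d)

# Assumption: ops listed here are left to PyTorch instead of being converted.
FALLBACK_OPS = {"torch.ops.aten._embedding_bag.default"}


def compile_with_fallback(exported_program, inputs):
    """Hypothetical helper: compile while keeping _embedding_bag in PyTorch."""
    # Imported lazily so the pure-PyTorch part above runs without TensorRT.
    import torch_tensorrt

    return torch_tensorrt.dynamo.compile(
        exported_program,
        inputs=inputs,
        min_block_size=1,
        torch_executed_ops=FALLBACK_OPS,
    )
```

With the repro above, the helper would be called as `compile_with_fallback(ep, error_model_input)`; this needs a GPU with Torch-TensorRT installed and is untested against this issue.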