NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
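
For context, a minimal sketch of that high-level Python API, paraphrased from the project's quick-start (the exact import path has moved between releases; around v0.11 it lived under tensorrt_llm.hlapi):

    from tensorrt_llm import LLM, SamplingParams

    # Build (or load) an engine for a Hugging Face model and run inference.
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    for output in llm.generate(["Hello, my name is"], sampling_params):
        print(output.outputs[0].text)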
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

How to use Medusa to support encoder decoder model? #2387


TianzhongSong commented 3 weeks ago

TRT-LLM version: v0.11.0

I'm deploying a BART model with Medusa heads, and I noticed this related issue: https://github.com/NVIDIA/TensorRT-LLM/issues/1946. I then adapted my model with the following steps:

   1. Adapted BART training in Medusa to obtain trained draft heads (a sketch of the head structure follows this list)
   2. Modified models/medusa/model.py to support enc_dec models
   3. Modified models/enc_dec/model.py to accept speculative-decoding parameters
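
For step 1, here is a minimal PyTorch sketch of the head structure Medusa trains, per the Medusa paper: each head is one or more residual blocks followed by a linear LM head that drafts the token k positions ahead. Class and parameter names are illustrative, not TensorRT-LLM's own Module classes:

    import torch
    from torch import nn

    class ResBlock(nn.Module):
        """Residual block used by each Medusa head (per the Medusa paper)."""
        def __init__(self, hidden_size: int):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)
            self.act = nn.SiLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.act(self.linear(x))

    class MedusaHead(nn.Module):
        """One draft head: ResBlock(s) plus a linear LM head that predicts
        the token k positions ahead from the decoder's last hidden state."""
        def __init__(self, hidden_size: int, vocab_size: int, num_layers: int = 1):
            super().__init__()
            self.blocks = nn.Sequential(
                *(ResBlock(hidden_size) for _ in range(num_layers)))
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            return self.lm_head(self.blocks(hidden_states))

    # For a BART-style encoder-decoder, the heads would consume the decoder's
    # final hidden states (i.e., after cross-attention), one head per draft
    # position.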

However, I encountered the following error:

  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/runtime/generation.py", line 930, in wrapper
    ret = func(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/runtime/generation.py", line 3377, in decode
    assert not self.cross_attention
AssertionError
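
The assertion fires in GenerationSession.decode(): the runtime explicitly rejects sessions that use cross-attention when Medusa is active. A hypothetical pre-flight check mirroring that guard (cross_attention appears in the assert above; medusa_enabled is an assumed name, not a documented attribute):

    # Hypothetical check mirroring the runtime guard that raised above.
    def check_medusa_compatible(session) -> None:
        if getattr(session, "cross_attention", False) and \
                getattr(session, "medusa_enabled", False):
            raise RuntimeError(
                "Medusa speculative decoding is not supported for "
                "encoder-decoder engines (cross-attention present).")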

Can cross-attention not be used with Medusa? Any ideas?

hello-11 commented 3 weeks ago

Thank you for your question. Unfortunately, Medusa doesn't support cross-attention yet, and we don't plan to add this feature. However, if you can share more details about your use case, we would be happy to consider it further.