NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Converting the Large world model in TensorRT LLM #1174

Open zmy1116 opened 4 months ago

zmy1116 commented 4 months ago

Hello,

We are interested in deploying the recently published Large World Model (LWM) in Triton server:

https://largeworldmodel.github.io/ https://huggingface.co/LargeWorldModel/LWM-Chat-32K-Jax/tree/main

The model has the same architecture as LLaMA, so I assume all I need to do is follow the LLaMA TensorRT-LLM guideline to convert the model with RoPE scaling. Except the weights are in JAX...

I don't know JAX... I know you have JAX conversion code for Gemma, so I suppose writing a convert_checkpoint for the LWM model isn't very difficult. My problem is that I'm not familiar enough with either TensorRT-LLM or JAX to know what to do and how to proceed.

Sorry the question is not so clear; honestly, I don't know enough to ask precise questions at this point.

Thanks
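
For what it's worth, the Gemma-style JAX conversion boils down to flattening the Flax parameter tree, renaming keys to the HF LLaMA layout, and transposing dense kernels. Below is a minimal sketch of that idea. The key fragments in `RENAME` are assumptions for illustration only; the actual LWM parameter names would need to be read off the real checkpoint (e.g. by printing its param tree) before any of this is usable.

```python
import numpy as np

def flatten_params(tree, prefix=""):
    """Flatten a nested JAX/Flax-style param dict into {'a/b/c': array} form."""
    flat = {}
    for name, value in tree.items():
        key = f"{prefix}/{name}" if prefix else name
        if isinstance(value, dict):
            flat.update(flatten_params(value, key))
        else:
            flat[key] = np.asarray(value)  # np.asarray also accepts jax.Array
    return flat

# Hypothetical JAX -> HF-LLaMA name fragments; check against the real checkpoint.
RENAME = {
    "wte/embedding": "model.embed_tokens.weight",
    "attention/wq/kernel": "self_attn.q_proj.weight",
    "attention/wk/kernel": "self_attn.k_proj.weight",
    "attention/wv/kernel": "self_attn.v_proj.weight",
}

def to_hf_name(jax_key):
    for frag, hf_name in RENAME.items():
        if jax_key.endswith(frag):
            return hf_name
    return jax_key  # leave unmapped keys untouched for manual inspection

def convert(tree):
    out = {}
    for key, arr in flatten_params(tree).items():
        # Flax stores dense kernels as (in, out); torch Linear expects (out, in).
        if key.endswith("kernel"):
            arr = arr.T
        out[to_hf_name(key)] = arr
    return out
```

Once the weights are in HF LLaMA naming, the standard LLaMA `convert_checkpoint.py` path (with RoPE scaling configured) should apply unchanged.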

nivibilla commented 4 months ago

They also have HF versions. Just hidden under the expand models button https://huggingface.co/LargeWorldModel/LWM-Text-32K

zmy1116 commented 4 months ago

> They also have HF versions. Just hidden under the expand models button https://huggingface.co/LargeWorldModel/LWM-Text-32K

This is the text-only model. The interesting ones are the ones with video input, which only have JAX weights.