NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Why is llava trt-llm not much faster than transformers? #1123

Open · bleedingfight opened this issue 6 months ago

bleedingfight commented 6 months ago

Reproduction

  1. bash run_demo.sh

Expected behavior

Output should be much faster than with transformers.

actual behavior

[02/21/2024-11:36:57] [TRT-LLM] [I] TensorRT vision encoder latency: 0.0076348876953125 sec
[02/21/2024-11:36:57] [TRT-LLM] [I] TensorRT-LLM LLM latency: 4.899357304573059 sec
[02/21/2024-11:36:57] [TRT-LLM] [I] Generate latency: 4.9331824231147765 sec

additional notes

I only modified a small part of the code that is not closely related to the algorithm; here is my llava.py:

import argparse
import os
import sys
from pathlib import Path

import numpy as np
import requests
import tensorrt as trt
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (AutoConfig, AutoProcessor, AutoTokenizer,
                          Blip2Processor, NougatProcessor, NougatTokenizerFast)

import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger
from tensorrt_llm._utils import torch_to_numpy
from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo

sys.path.append(str(Path(__file__).parent.parent))
from enc_dec.run import TRTLLMEncDecModel

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_new_tokens', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument('--visual_engine_dir',
                        type=str,
                        default=None,
                        help='Directory containing visual TRT engines')
    parser.add_argument('--llm_engine_dir',
                        type=str,
                        default=None,
                        help='Directory containing TRT-LLM engines')
    parser.add_argument('--hf_model_dir',
                        type=str,
                        default=None,
                        help="Directory containing tokenizer")
    parser.add_argument(
        '--decoder_llm',
        action='store_true',
        help='Set if the LLM is decoder-only (as opposed to encoder-decoder)')
    parser.add_argument('--input_text',
                        type=str,
                        default='Question: which city is this? Answer:',
                        help='Text prompt to LLM')
    parser.add_argument('--num_beams',
                        type=int,
                        help="Use beam search if num_beams >1",
                        default=1)
    parser.add_argument('--top_k', type=int, default=1)
    # Referenced by the encoder-decoder (Nougat) path in init_llm() below.
    parser.add_argument('--nougat', action='store_true')

    return parser.parse_args()

def trt_dtype_to_torch(dtype):
    if dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    elif dtype == trt.int32:
        return torch.int32
    else:
        raise TypeError("%s is not supported" % dtype)

class MultiModalModel:

    def __init__(self, args):
        self.args = args

        runtime_rank = tensorrt_llm.mpi_rank()
        device_id = runtime_rank % torch.cuda.device_count()
        torch.cuda.set_device(device_id)
        self.stream = torch.cuda.current_stream().cuda_stream

        self.init_image_encoder()
        self.init_tokenizer()
        self.init_llm()

    def init_tokenizer(self):
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.args.hf_model_dir, use_fast=False, use_legacy=False)

        self.tokenizer.padding_side = "right"
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def init_image_encoder(self):
        vit_path = os.path.join(self.args.visual_engine_dir,
                                'visual_encoder_fp16.engine')
        logger.info(f'Loading engine from {vit_path}')
        with open(vit_path, 'rb') as f:
            engine_buffer = f.read()
        logger.info(f'Creating session from engine {vit_path}')
        self.visual_encoder_session = Session.from_serialized_engine(
            engine_buffer)

    def init_llm(self):
        if self.args.decoder_llm:
            self.model = ModelRunner.from_dir(self.args.llm_engine_dir,
                                              rank=tensorrt_llm.mpi_rank(),
                                              debug_mode=False)
            self.model_config = self.model.session._model_config
        else:
            self.model = TRTLLMEncDecModel.from_engine(
                self.args.hf_model_dir.split('/')[-1],
                self.args.llm_engine_dir,
                skip_encoder=self.args.nougat,
                debug_mode=False)

            if self.args.nougat:
                self.model_config = self.model.decoder_model_config
                self.runtime_mapping = self.model.decoder_runtime_mapping
            else:
                self.model_config = self.model.encoder_model_config
                self.runtime_mapping = self.model.encoder_runtime_mapping

            config = AutoConfig.from_pretrained(self.args.hf_model_dir)
            decoder_start_id = config.decoder_start_token_id
            if decoder_start_id is None:
                decoder_start_id = self.tokenizer.bos_token_id

            decoder_input_ids = torch.IntTensor([[decoder_start_id]]).to("cuda")
            batch_size = self.args.batch_size
            self.decoder_input_ids = decoder_input_ids.repeat((batch_size, 1))

    def generate(self, pre_prompt, post_prompt, image, max_new_tokens):
        profiler.start("Generate")
        profiler.start("Vision")
        visual_features, visual_atts = self.get_visual_features(image)
        profiler.stop("Vision")

        pre_input_ids = self.tokenizer(pre_prompt,
                                       return_tensors="pt",
                                       padding=True).input_ids.to("cuda")
        if post_prompt is not None:
            post_input_ids = self.tokenizer(post_prompt,
                                            return_tensors="pt",
                                            padding=True).input_ids.to("cuda")
            length = pre_input_ids.shape[1] + post_input_ids.shape[
                1] + visual_atts.shape[1]
        else:
            post_input_ids = None
            length = pre_input_ids.shape[1] + visual_atts.shape[1]

        input_lengths = torch.IntTensor([length]).to(torch.int32).to("cuda")
        input_ids, ptuning_args = self.setup_fake_prompts(
            visual_features, pre_input_ids, post_input_ids, input_lengths)

        if self.args.decoder_llm:
            prompt_table = ptuning_args[0]
            prompt_table = torch.stack([prompt_table])
            np.save('prompt_table.npy', torch_to_numpy(prompt_table))

        profiler.start("LLM")
        if self.args.decoder_llm:
            end_id = self.tokenizer.eos_token_id
            output_ids = self.model.generate(
                input_ids.to("cpu"),
                sampling_config=None,
                prompt_table_path='prompt_table.npy',
                max_new_tokens=max_new_tokens,
                end_id=end_id,
                pad_id=self.tokenizer.pad_token_id,
                top_k=self.args.top_k,
                num_beams=self.args.num_beams,
                output_sequence_lengths=False,
                return_dict=False)
        else:
            output_ids = self.model.generate(
                input_ids,
                self.decoder_input_ids,
                max_new_tokens,
                num_beams=self.args.num_beams,
                bos_token_id=self.tokenizer.bos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                debug_mode=False,
                prompt_embedding_table=ptuning_args[0],
                prompt_tasks=ptuning_args[1],
                prompt_vocab_size=ptuning_args[2])

            # Reset input_lengths to match decoder_input_ids
            input_lengths = torch.ones(input_lengths.shape,
                                       dtype=input_lengths.dtype)
        profiler.stop("LLM")

        if tensorrt_llm.mpi_rank() == 0:
            # Extract a list of tensors of shape beam_width x output_ids.
            output_beams_list = [
                self.tokenizer.batch_decode(
                    output_ids[batch_idx, :, input_lengths[batch_idx]:],
                    skip_special_tokens=True)
                for batch_idx in range(self.args.batch_size)
            ]

            stripped_text = [[
                output_beams_list[batch_idx][beam_idx].strip()
                for beam_idx in range(self.args.num_beams)
            ] for batch_idx in range(self.args.batch_size)]
            profiler.stop("Generate")
            return stripped_text
        else:
            profiler.stop("Generate")
            return None

    def get_visual_features(self, image):
        visual_features = {'input': image.half()}
        visual_output_info = self.visual_encoder_session.infer_shapes(
            [TensorInfo('input', trt.DataType.HALF, image.shape)])
        visual_outputs = {
            t.name: torch.empty(tuple(t.shape),
                                dtype=trt_dtype_to_torch(t.dtype),
                                device="cuda")
            for t in visual_output_info
        }

        ok = self.visual_encoder_session.run(visual_features, visual_outputs,
                                             self.stream)
        assert ok, "Runtime execution failed for vit session"
        torch.cuda.synchronize()

        image_embeds = visual_outputs['output']
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to("cuda")

        return image_embeds, image_atts

    def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids,
                           input_lengths):
        # Assemble fake prompts which points to image embedding actually
        fake_prompt_id = torch.arange(
            self.model_config.vocab_size,
            self.model_config.vocab_size +
            visual_features.shape[0] * visual_features.shape[1],
            device="cuda")
        fake_prompt_id = fake_prompt_id.reshape(visual_features.shape[0],
                                                visual_features.shape[1])

        if post_input_ids is not None:
            input_ids = [pre_input_ids, fake_prompt_id, post_input_ids]
        else:
            input_ids = [fake_prompt_id, pre_input_ids]
        input_ids = torch.cat(input_ids,
                              dim=1).contiguous().to(torch.int32).cuda()

        if self.args.decoder_llm or self.runtime_mapping.is_first_pp_rank():
            ptuning_args = self.ptuning_setup(visual_features, input_ids,
                                              input_lengths)
        else:
            ptuning_args = [None, None, None]

        return input_ids, ptuning_args

    def ptuning_setup(self, prompt_table, input_ids, input_lengths):
        if prompt_table is not None:
            task_vocab_size = torch.tensor([prompt_table.shape[1]],
                                           dtype=torch.int32,
                                           device="cuda")
            prompt_table = prompt_table.view(
                (prompt_table.shape[0] * prompt_table.shape[1],
                 prompt_table.shape[2]))

            hidden_size = self.model_config.hidden_size
            if not self.args.decoder_llm:
                hidden_size *= self.runtime_mapping.tp_size
            assert prompt_table.shape[
                1] == hidden_size, "Prompt table dimensions do not match hidden size"

            prompt_table = prompt_table.cuda().to(
                dtype=tensorrt_llm._utils.str_dtype_to_torch(
                    self.model_config.dtype))
        else:
            hidden_size = self.model_config.hidden_size  # recompute here to avoid a NameError
            if not self.args.decoder_llm:
                hidden_size *= self.runtime_mapping.tp_size
            prompt_table = torch.empty([1, hidden_size]).cuda()
            task_vocab_size = torch.zeros([1]).cuda()

        if self.model_config.remove_input_padding:
            tasks = torch.zeros([torch.sum(input_lengths)],
                                dtype=torch.int32).cuda()
            if self.args.decoder_llm:
                tasks = tasks.unsqueeze(0)
        else:
            tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda()

        return [prompt_table, tasks, task_vocab_size]

def load_test_image(model_name):
    if "nougat" in model_name:
        filepath = hf_hub_download(
            repo_id="hf-internal-testing/fixtures_docvqa",
            filename="nougat_paper.png",
            repo_type="dataset")
        image = Image.open(filepath)
    else:
        img_url = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png'
        image = Image.open(requests.get(img_url,
                                        stream=True).raw).convert('RGB')

    return image

def load_image(filename):
    if isinstance(filename, Path):
        image = Image.open(str(filename)).convert('RGB')
    else:
        image = load_test_image(filename)
    return image

if __name__ == '__main__':
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = parse_arguments()
    torch.manual_seed(1234)
    tensorrt_llm.logger.set_level(args.log_level)
    runtime_rank = tensorrt_llm.mpi_rank()

    filename = Path(__file__).parent.parent.joinpath('view.jpg')
    image = load_image(filename) if filename.exists() else load_image(args.hf_model_dir)

    processor = AutoProcessor.from_pretrained(args.hf_model_dir)
    image = processor(text=args.input_text,
                      images=image,
                      return_tensors="pt")['pixel_values']
    image = image.half().to("cuda")

    pre_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\n"
    post_prompt = args.input_text + " ASSISTANT:"

    model = MultiModalModel(args)

    num_iters = 100
    for _ in range(num_iters):
        stripped_text = model.generate(pre_prompt, post_prompt, image,
                                       args.max_new_tokens)

    if runtime_rank == 0:
        logger.info("---------------------------------------------------------")
        logger.info(f"\n[Q] {args.input_text}")
        logger.info(f"\n[A] {stripped_text}")
        logger.info(
            f'TensorRT vision encoder latency: {profiler.elapsed_time_in_sec("Vision") / num_iters} sec'
        )
        logger.info(
            f'TensorRT-LLM LLM latency: {profiler.elapsed_time_in_sec("LLM") / num_iters} sec'
        )
        logger.info(
            f'Generate latency: {profiler.elapsed_time_in_sec("Generate") / num_iters} sec'
        )
        logger.info("---------------------------------------------------------")

run_demo.sh

#!/bin/bash
MODEL_NAME=llava-v1.5-7b-hf
LLAVA_MODEL=$HOME/huggingface/$MODEL_NAME
CHECKPOINT_PATH=models/trt_models/$MODEL_NAME/fp16/1-gpu
ENGINE_PATH=models/trt_engines/${MODEL_NAME}/fp16/1-gpu
VISUAL_ENGINE=models/visual_engines
MAX_BATCH_SIZE=1
DTYPE=float16
MAX_MULTIMODAL_LEN=$(($MAX_BATCH_SIZE * 576))
# PROMPT="Question: which city is this? Answer:"
# PROMPT="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\n详细描述这张图片,用中文回答,要求字数接近一 百二十个汉字。 ASSISTANT:"
PROMPT="详细描述这张图片,用中文回答,要求字数接近一百二十个汉字。"
printf "\033[32mExtract hf model's' to checkpoint!\n\033[0m"
if [ ! -d $CHECKPOINT_PATH ]; then
        python convert_checkpoint.py \
                --model_dir $LLAVA_MODEL \
                --output_dir $CHECKPOINT_PATH \
                --dtype $DTYPE
fi

printf "\033[32mTry to convert checkpoint to tensorrt engine!\n\033[0m"
if [ ! -d $ENGINE_PATH ]; then
        trtllm-build \
                --checkpoint_dir $CHECKPOINT_PATH \
                --output_dir $ENGINE_PATH \
                --gpt_attention_plugin $DTYPE \
                --gemm_plugin float16 \
                --max_batch_size $MAX_BATCH_SIZE \
                --max_input_len 2048 \
                --max_output_len 512 \
                --max_multimodal_len $MAX_MULTIMODAL_LEN # 1 (max_batch_size) * 576 (num_visual_features)
fi
printf "\033[32mTry to convert checkpoint to tensorrt engine!\n\033[0m"
if [ ! -d $VISUAL_ENGINE ]; then
        python build_visual_engine.py --model_name ${MODEL_NAME} --model_path $LLAVA_MODEL --output_dir $VISUAL_ENGINE
fi
printf "\033[32mRUN Engine!\n\033[0m"
export PYTHONPATH=$HOME/workspace/TensorRT-LLM/examples:$PYTHONPATH

# $script is not set anywhere above, so this case statement falls through
# to the default branch and runs llava.py.
case $script in
'hub_demo.py')
        script=hub_demo.py
        ;;
*)
        script=llava.py
        ;;
esac

python "$script" \
        --max_new_tokens 200 \
        --input_text "$PROMPT" \
        --hf_model_dir $LLAVA_MODEL \
        --visual_engine_dir $VISUAL_ENGINE/${MODEL_NAME} \
        --llm_engine_dir $ENGINE_PATH \
        --top_k 50 \
        --decoder_llm

Pipeline with transformers (llava_demo.py):

import requests
from PIL import Image
import time
from pathlib import Path
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

# model_id = "llava-hf/llava-1.5-7b-hf"
model_id = "/home/user/huggingface/llava-v1.5-7b-hf"

# prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
torch.manual_seed(1234)
# Same Chinese prompt as run_demo.sh: "Describe this image in detail, answer in Chinese, in roughly 120 Chinese characters."
prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\n详细描述这张图片,用中文回答,要求字数接近一百二十个汉字。 ASSISTANT:"
# image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
image_file = Path(__file__).parent.joinpath('view.jpg')

model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)

processor = AutoProcessor.from_pretrained(model_id)

raw_image = Image.open(image_file)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
t0 = time.time()
output = model.generate(**inputs, max_new_tokens=200, do_sample=True)
t1 = time.time()
print(f"cost time: {t1 - t0:.5f}")
mesg = processor.decode(output[0][2:], skip_special_tokens=True)
print(mesg)

Although the above code may take around 10 seconds for a single inference on my machine, that may have been an outlier; my earlier testing took about 5 seconds. A single run fluctuates significantly, but after many runs it stabilizes at around 5 seconds. Why is llava trt-llm not much faster than transformers?

During operation, GPU utilization did not reach 100%.
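
Run-to-run fluctuation like this is usually dominated by the first, cold iteration. A minimal timing sketch for the transformers side, assuming the same model and inputs objects as in llava_demo.py above, warms up first and then averages several runs to get a steadier number:

import time

import torch

# Assumes `model` and `inputs` from llava_demo.py above are already on the GPU.
def bench(n_iters=10, warmup=2):
    for _ in range(warmup):
        # Cold runs absorb CUDA context creation, kernel selection, and cache warm-up.
        model.generate(**inputs, max_new_tokens=200, do_sample=True)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(n_iters):
        model.generate(**inputs, max_new_tokens=200, do_sample=True)
    torch.cuda.synchronize()
    return (time.time() - t0) / n_iters

print(f"average latency over 10 runs: {bench():.3f} s")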

Gutianpei commented 6 months ago

Same issue

symphonylyh commented 6 months ago

cross-ref #1118

Hi, there was a bug causing non-stopping behavior in the Python runtime in general. Can you please check #1118 and re-test the upcoming weekly release to verify this issue is gone? Thanks.

kaiyux commented 6 months ago

cross-ref https://github.com/NVIDIA/TensorRT-LLM/issues/1118#issuecomment-1968045987

The main branch has been updated. Thanks.

bleedingfight commented 6 months ago

@kaiyux thanks for your update. I tried compiling the latest code in Docker (tensorrt-llm:0.9.0.dev2024022700) on a 3090. The engine was built with a maximum batch size of 8, and I ran run.py with batch = 1, 2, 4, 8. My test:

batch | 0.9.0.dev2024022700 (s) | 0.9.0.dev2024020600 (s)
1     | 4.064309024810791       | 4.487630772590637
2     | 4.015249347686767       | 4.456121301651001
4     | 4.471942019462586       | 4.894955468177796
8     | 5.405202007293701       | 5.83393361568451

amukkara commented 6 months ago

Hi @bleedingfight, what is the actual number of generated tokens for transformers and TRT-LLM?

Although max_new_tokens is set to 200, the transformers output has only 22 tokens, while the TRT-LLM output has 200 tokens. When both TRT-LLM and transformers generate the same number of tokens (by setting max_new_tokens=22), TRT-LLM is 5x faster.
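
A quick way to check this on the transformers side (a sketch that reuses the inputs and output variables from llava_demo.py above) is to subtract the prompt length from the returned sequence length:

# Reuses `inputs` and `output` from llava_demo.py above.
prompt_len = inputs["input_ids"].shape[1]   # text prompt plus <image> placeholder tokens
generated = output.shape[1] - prompt_len    # tokens actually produced by generate()
print(f"transformers generated {generated} new tokens")
# With do_sample=True the model can emit EOS long before max_new_tokens=200,
# while the TRT-LLM run above keeps decoding for the full 200 tokens.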

bleedingfight commented 6 months ago

@amukkara In my application scenario, we need to output image descriptions for many images, so max_new_tokens is the same for different images. I can indeed improve performance by reducing the value of max_new_tokens, but that is only because the number of auto-regressive decoding iterations has decreased. I think the optimization in TRT-LLM is not deep enough.

amukkara commented 6 months ago

@bleedingfight Using a higher max_new_tokens is not an issue, but we should ensure TRT-LLM and transformers generate similar output lengths to make a fair comparison. You can change options to generate() such as top_k and min_new_tokens and replicate them in the transformers script, for example as sketched below. Once the output token counts are similar, TRT-LLM should be much faster than transformers.
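
On the transformers side, that could look like the following sketch (based on llava_demo.py above; min_new_tokens is the standard generate() argument for forcing a minimum output length):

# Force transformers to decode a number of tokens comparable to the TRT-LLM run,
# which run_demo.sh launches with --max_new_tokens 200 --top_k 50.
output = model.generate(
    **inputs,
    max_new_tokens=200,
    min_new_tokens=200,   # keep decoding until 200 new tokens, ignoring early EOS
    do_sample=True,
    top_k=50,
)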

bleedingfight commented 6 months ago

@amukkara How can I maintain consistent output between trt-llm and transformers? I use the generate interface in transformers with max_new_tokens=200, top_k=50, and the same prompt, and the outputs of transformers and trt-llm are different. I want to know how NVIDIA officially verifies the correctness of the TRT model output; it seems difficult for developers themselves to know how NVIDIA ensures it.
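
A common way to make such a comparison deterministic (a sketch, not an official verification procedure) is to disable sampling on both sides, since sampled outputs differ between runs even within transformers itself:

# transformers side (llava_demo.py): greedy decoding, deterministic for a given input.
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)

# TRT-LLM side (run_demo.sh): pass --top_k 1 so llava.py's ModelRunner.generate()
# always picks the highest-probability token. The two token sequences should then
# match closely; small divergences can still come from fp16 numerical differences.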

XuecaiHu commented 5 months ago

> @amukkara How can I maintain consistent output between trt-llm and transformers? I use the generate interface in transformers with max_new_tokens=200, top_k=50, and the same prompt, and the outputs of transformers and trt-llm are different. I want to know how NVIDIA officially verifies the correctness of the TRT model output; it seems difficult for developers themselves to know how NVIDIA ensures it.

Same question. We tested on qwen-vl-chat. Have you solved this problem yet? Looking forward to any suggestions.