NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

How to enable in-flight batching in TensorRT-LLM v0.9.0 #1729

Closed. Godlovecui closed this issue 4 months ago.

Godlovecui commented 4 months ago

System Info

8× RTX 4090

Who can help?

@kaiyux @ncomly-nvidia @jun

Reproduction

I want to test the in-flight batching feature. How do I enable it when building the engine? The "--use_inflight_batching" flag has been removed from trtllm-build in v0.9.0.

Expected behavior

In-flight batching is enabled correctly.

Actual behavior

I am not sure how to enable it in v0.9.0.


hijkzzz commented 4 months ago
  1. Build the engine with --max_batch_size > 1 (a sketch of the build command follows the snippet below).

  2. Set batching_type=trtllm.BatchingType.INFLIGHT in the ExecutorConfig:

    trtllm.ExecutorConfig( 
            1, 
            parallel_config=trt_parallel_config, 
            normalize_log_probs=False, 
            batching_type=trtllm.BatchingType.INFLIGHT, 
            scheduler_config=trt_scheduler_config, 
        ), 
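
For step 1, a minimal trtllm-build invocation might look like the sketch below. The checkpoint and engine paths are placeholders (not from this thread); the plugin, packed-input, and paged-KV-cache flags are the build-time options that in-flight batching relies on, and in recent releases they are typically already the defaults, so usually only --max_batch_size needs attention.

trtllm-build \
    --checkpoint_dir ./tmp/llama3-8b-tp2-ckpt \
    --output_dir ./tmp/llama3-8b-tp2-engine \
    --max_batch_size 8 \
    --gpt_attention_plugin float16 \
    --remove_input_padding enable \
    --paged_kv_cache enable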

Here is an example for tensorrt_llm==0.10.0.dev2024050700

import argparse 
import logging 
import time 
from datetime import datetime, timedelta 
from pathlib import Path 
from threading import Thread 

import tensorrt_llm 
import tensorrt_llm.bindings.executor as trtllm 
from transformers import PreTrainedTokenizerFast 

logger = logging.getLogger(__name__) 

def tensorrt_llm_executor_worker_path() -> str: 
    worker_path = Path(tensorrt_llm.__file__).parent / 'bin' / 'executorWorker' 
    if not worker_path.exists(): 
        raise Exception("TensorRT-LLM executor worker not found") 
    return str(worker_path) 

def get_trt_parallel_config(): 
    world_size = 2 
    if world_size > 1: 
        executor_worker_path = tensorrt_llm_executor_worker_path() 
        orchestrator_config = trtllm.OrchestratorConfig(True, executor_worker_path) 
        return trtllm.ParallelConfig( 
            trtllm.CommunicationType.MPI, 
            trtllm.CommunicationMode.ORCHESTRATOR, 
            orchestrator_config=orchestrator_config, 
            # TODO:BIS fix device_ids 
            device_ids=[0, 1], 
        ) 
    else: 
        return trtllm.ParallelConfig(trtllm.CommunicationType.MPI, trtllm.CommunicationMode.LEADER) 

def create_executor(model_path: str) -> trtllm.Executor: 
    trt_parallel_config = get_trt_parallel_config() 
    trt_scheduler_config = trtllm.SchedulerConfig(trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT) 

    return trtllm.Executor( 
        Path(model_path), 
        trtllm.ModelType.DECODER_ONLY, 
        trtllm.ExecutorConfig( 
            1,  # max_beam_width
            parallel_config=trt_parallel_config, 
            normalize_log_probs=False, 
            # In-flight batching is selected here, at runtime; it is not a build flag.
            batching_type=trtllm.BatchingType.INFLIGHT,
            scheduler_config=trt_scheduler_config, 
        ), 
    ) 

def create_request(input_ids, output_len, eos_id: int, sample_params): 
    output_config = trtllm.OutputConfig(exclude_input_from_output=True) 
    ## This seems to somewhat resolve the issue 
    # sampling_config = trtllm.SamplingConfig(beam_width=1, frequency_penalty=1.0) 
    request = trtllm.Request( 
        input_token_ids=input_ids, 
        max_new_tokens=output_len, 
        streaming=True, 
        output_config=output_config, 
        end_id=eos_id, 
        sampling_config=sample_params, 
    ) 
    return request 

trt_id = None 

def main(): 
    default_prompt = "You have been tasked with designing a solar-powered water heating system for a residential building. Describe the key components and considerations you would include in your design. Design a five-step workflow.!" 
    # default_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\nYou have been tasked with designing a solar-powered water heating system for a residential building. Describe the key components and considerations you would include in your design. Design a five-step workflow.!<|eot_id|><|start_header_id|>assistant<|end_header_id|>" 
    parser = argparse.ArgumentParser() 
    parser.add_argument("--model_path", required=False, default="./tmp/llama3-8b-tp2-engine") 
    parser.add_argument("--tokenizer_path", required=False, default="/home/scratch.trt_llm_data/llm-models/llama-models-v3/llama-v3-8b-instruct-hf/") 
    parser.add_argument("--prompt", required=False, default=default_prompt) 

    args = parser.parse_args() 

    tokenizer = PreTrainedTokenizerFast.from_pretrained(args.tokenizer_path) 
    executor = create_executor(args.model_path) 
    prompt = args.prompt 
    prompt_ids = tokenizer.encode(prompt) 
    print(prompt_ids) 

    def do_decode(sampling_config): 
        output_ids = [] 
        finished = False 
        req = create_request(prompt_ids, 150, tokenizer.eos_token_id, sampling_config) 
        _ = executor.enqueue_request(req) 
        while not finished: 
            responses = executor.await_responses(timeout=timedelta(seconds=1)) 
            for r in responses: 
                if r.has_error(): 
                    raise RuntimeError(r.error_msg) 
                result = r.result 
                output_ids.extend(result.output_token_ids[0]) 
                if result.is_final: 
                    finished = True 
        return tokenizer.decode(output_ids) 

    print(do_decode(trtllm.SamplingConfig(beam_width=1, top_k=1, random_seed=1234)))     
    print("===================================") 

    executor.shutdown() 

if __name__ == "__main__": 
    main() 
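
Because the example uses CommunicationMode.ORCHESTRATOR (and looks up the executorWorker binary for exactly that purpose), the script is launched as a single ordinary Python process and the executor spawns the MPI worker processes itself, so no mpirun wrapper is needed. Assuming the code above is saved as executor_example.py (an arbitrary name), it would be run roughly like this, with the paths adjusted to your engine and tokenizer:

python executor_example.py \
    --model_path ./tmp/llama3-8b-tp2-engine \
    --tokenizer_path /path/to/llama-v3-8b-instruct-hf
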
Godlovecui commented 4 months ago
> 1. Build the engine with --max_batch_size > 1.
> 2. Set batching_type=trtllm.BatchingType.INFLIGHT in the ExecutorConfig.

Can you tell me the path of this file? I cannot find it. There are two steps: the first is "python convert_checkpoint.py xxx" and the second is "trtllm-build xxx"; after that the engine is deployed with Triton server. At which step would I need to modify this file? Thank you~ @hijkzzz

hijkzzz commented 4 months ago
> Can you tell me the path of this file? [...] At which step would I need to modify this file?

You only need to compile the engine with --max_batch_size > 1; in-flight batching itself is selected at runtime (via batching_type in the ExecutorConfig), not by a trtllm-build flag.
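
Putting that together with the two steps from the question: nothing changes in the convert_checkpoint.py step, and trtllm-build only needs a --max_batch_size greater than 1 (see the build sketch earlier in this thread). When the engine is then served through Triton, in-flight batching is normally selected in the tensorrtllm_backend model configuration rather than at build time. The fragment below is a sketch based on the inflight_batcher_llm example in the tensorrtllm_backend repository (verify the parameter name and file layout against your backend version); it would go in the tensorrt_llm model's config.pbtxt:

parameters: {
  key: "gpt_model_type"
  value: {
    string_value: "inflight_fused_batching"
  }
}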