- build the engine with `--max_batch_size` > 1
- set `batching_type=trtllm.BatchingType.INFLIGHT` when creating the `ExecutorConfig`:

```python
trtllm.ExecutorConfig(
    1,
    parallel_config=trt_parallel_config,
    normalize_log_probs=False,
    batching_type=trtllm.BatchingType.INFLIGHT,
    scheduler_config=trt_scheduler_config,
),
```
Here is an example for `tensorrt_llm==0.10.0.dev2024050700`:
```python
import argparse
import logging
import time
from datetime import datetime, timedelta
from pathlib import Path
from threading import Thread

import tensorrt_llm
import tensorrt_llm.bindings.executor as trtllm
from transformers import PreTrainedTokenizerFast

logger = logging.getLogger(__name__)


def tensorrt_llm_executor_worker_path() -> str:
    worker_path = Path(tensorrt_llm.__file__).parent / 'bin' / 'executorWorker'
    if not worker_path.exists():
        raise Exception("TensorRT-LLM executor worker not found")
    return str(worker_path)


def get_trt_parallel_config():
    world_size = 2
    if world_size > 1:
        executor_worker_path = tensorrt_llm_executor_worker_path()
        orchestrator_config = trtllm.OrchestratorConfig(True, executor_worker_path)
        return trtllm.ParallelConfig(
            trtllm.CommunicationType.MPI,
            trtllm.CommunicationMode.ORCHESTRATOR,
            orchestrator_config=orchestrator_config,
            # TODO:BIS fix device_ids
            device_ids=[0, 1],
        )
    else:
        return trtllm.ParallelConfig(trtllm.CommunicationType.MPI, trtllm.CommunicationMode.LEADER)


def create_executor(model_path: str) -> trtllm.Executor:
    trt_parallel_config = get_trt_parallel_config()
    trt_scheduler_config = trtllm.SchedulerConfig(trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT)
    return trtllm.Executor(
        Path(model_path),
        trtllm.ModelType.DECODER_ONLY,
        trtllm.ExecutorConfig(
            1,
            parallel_config=trt_parallel_config,
            normalize_log_probs=False,
            batching_type=trtllm.BatchingType.INFLIGHT,
            scheduler_config=trt_scheduler_config,
        ),
    )


def create_request(input_ids, output_len, eos_id: int, sample_params):
    output_config = trtllm.OutputConfig(exclude_input_from_output=True)
    ## This seems to somewhat resolve the issue
    # sampling_config = trtllm.SamplingConfig(beam_width=1, frequency_penalty=1.0)
    request = trtllm.Request(
        input_token_ids=input_ids,
        max_new_tokens=output_len,
        streaming=True,
        output_config=output_config,
        end_id=eos_id,
        sampling_config=sample_params,
    )
    return request


trt_id = None


def main():
    default_prompt = "You have been tasked with designing a solar-powered water heating system for a residential building. Describe the key components and considerations you would include in your design. Design a five-step workflow.!"
    # default_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\nYou have been tasked with designing a solar-powered water heating system for a residential building. Describe the key components and considerations you would include in your design. Design a five-step workflow.!<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=False, default="./tmp/llama3-8b-tp2-engine")
    parser.add_argument("--tokenizer_path", required=False, default="/home/scratch.trt_llm_data/llm-models/llama-models-v3/llama-v3-8b-instruct-hf/")
    parser.add_argument("--prompt", required=False, default=default_prompt)
    args = parser.parse_args()

    tokenizer = PreTrainedTokenizerFast.from_pretrained(args.tokenizer_path)
    executor = create_executor(args.model_path)

    prompt = args.prompt
    prompt_ids = tokenizer.encode(prompt)
    print(prompt_ids)

    def do_decode(sampling_config):
        output_ids = []
        finished = False
        req = create_request(prompt_ids, 150, tokenizer.eos_token_id, sampling_config)
        _ = executor.enqueue_request(req)
        while not finished:
            responses = executor.await_responses(timeout=timedelta(seconds=1))
            for r in responses:
                if r.has_error():
                    raise RuntimeError(r.error_msg)
                result = r.result
                output_ids.extend(result.output_token_ids[0])
                if result.is_final:
                    finished = True
        return tokenizer.decode(output_ids)

    print(do_decode(trtllm.SamplingConfig(beam_width=1, top_k=1, random_seed=1234)))
    print("===================================")
    executor.shutdown()


if __name__ == "__main__":
    main()
```
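As written, the script only ever has one request in flight, so it will not really exercise inflight batching even with a larger engine batch size. Below is a minimal sketch of a batched variant, assuming it is appended to the script above (it reuses `create_request`) and that `Response.request_id` can be used to map streamed chunks back to their originating request:

```python
from datetime import timedelta

# Hedged sketch: enqueue several requests up front so the executor can batch them
# in flight. Assumes the `create_request` helper from the script above and that
# `Response.request_id` identifies which request a streamed chunk belongs to.
def do_decode_batch(executor, tokenizer, prompt_ids, sampling_config, num_requests=4):
    reqs = [create_request(prompt_ids, 150, tokenizer.eos_token_id, sampling_config)
            for _ in range(num_requests)]
    # Enqueue everything before waiting so the requests overlap inside the executor.
    id_to_tokens = {executor.enqueue_request(r): [] for r in reqs}
    remaining = set(id_to_tokens)
    while remaining:
        for r in executor.await_responses(timeout=timedelta(seconds=1)):
            if r.has_error():
                raise RuntimeError(r.error_msg)
            result = r.result
            id_to_tokens[r.request_id].extend(result.output_token_ids[0])
            if result.is_final:
                remaining.discard(r.request_id)
    return [tokenizer.decode(toks) for toks in id_to_tokens.values()]
```

Calling this from `main()` in place of `do_decode` should show several sequences being generated concurrently by the same executor.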
Can you tell me the path of this file? I cannot find it. There are two steps: the first is `python convert_checkpoint.py xxx` and the second is `trtllm-build xxx`, and then the model is deployed in Triton server. In which step would I need to modify this file? Thank you~ @hijkzzz
You only need to compile the engine with `--max_batch_size` > 1.
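Neither of the two steps involves this file; it is a standalone client script, not something produced by `convert_checkpoint.py` or `trtllm-build`. Inflight batching only requires that the engine built in the second step has a batch size greater than 1. A rough sketch of the two commands for a TP2 LLaMA-3 8B build like the one in the example above; the paths, dtype, and batch size are placeholders, and the flag names come from the LLaMA example in the TensorRT-LLM repo, so verify them against your version:

```bash
# Step 1: convert the HF checkpoint to a TensorRT-LLM checkpoint (no inflight-batching flag here)
python convert_checkpoint.py \
    --model_dir /path/to/llama-v3-8b-instruct-hf \
    --output_dir ./tmp/llama3-8b-tp2-ckpt \
    --dtype float16 \
    --tp_size 2

# Step 2: build the engine; a batch size > 1 is all the build step needs for inflight batching
trtllm-build \
    --checkpoint_dir ./tmp/llama3-8b-tp2-ckpt \
    --output_dir ./tmp/llama3-8b-tp2-engine \
    --gemm_plugin float16 \
    --max_batch_size 8
```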
System Info
8x RTX 4090
Who can help?
@kaiyux @ncomly-nvidia @jun
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
I want to test the inflight batching feature; how do I enable it when building the engine? The `--use_inflight_batching` flag has been removed from `trtllm-build` in v0.9.0.
Expected behavior
Inflight batching is enabled correctly.
Actual behavior
I am not sure how to enable it in v0.9.0.
Additional notes
I want to test the inflight batching feature; how do I enable it when building the engine? The `--use_inflight_batching` flag has been removed from `trtllm-build` in v0.9.0.