triton-inference-server / pytriton

PyTriton is a Flask/FastAPI-like interface that simplifies Triton's deployment in Python environments.
https://triton-inference-server.github.io/pytriton/
Apache License 2.0
684 stars 45 forks source link

Multi-GPU inference with PyTriton gives worse TPS #75

Open lionsheep24 opened 2 weeks ago

lionsheep24 commented 2 weeks ago

Description

I implemented multi-instance inference across 4 A100 GPUs by following this, but the TPS measured with Locust was a bit lower, and the latency was even higher, compared to running multiple instances on a single GPU. Is there some overhead in multi-GPU inference with PyTriton, or is there a mistake in my code?

To reproduce: here is my code.

# server
import argparse
import logging
import time
import traceback
from datetime import datetime
from typing import List

import numpy as np
import torch
from pytriton.decorators import batch
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
from pytriton.triton import Triton, TritonConfig
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Number of inference callables bound per device (they share that device's model).
NUM_WORKER: int = 9
# Batch limit given to Triton's ModelConfig and to each HF pipeline.
MAX_BATCH_SIZE: int = 4
# Passed to DynamicBatcher: a request may wait up to 200 ms for a batch to form.
MAX_QUEUE_DELAY_MICROSECONDS: int = 200000

logger = logging.getLogger("pytriton")
torch_dtype = torch.float16
# NOTE(review): `model_ckpt` is not defined in this snippet; it must be bound
# (e.g. to a Whisper checkpoint id or local path) before this line runs.
# Whisper's tokenizer/processor take the decoding task via `task=`.
processor = AutoProcessor.from_pretrained(model_ckpt, language="ko", task="transcribe")

class _InferFuncWrapper:
    """Triton inference callable that transcribes a batch of audio with an ASR pipeline.

    Several wrappers may share one pipeline object (``multi_device_factory``
    creates ``NUM_WORKER`` wrappers per device around a single pipeline), so
    the wrapper itself keeps no per-request state.

    NOTE(review): if Triton invokes these wrappers concurrently, several
    callers use the same pipeline/model at once -- confirm that is safe and
    actually faster than fewer workers per device.
    """

    def __init__(self, pipeline):
        # Hugging Face pipeline that every batch is forwarded to.
        self._pipeline = pipeline

    @batch
    def __call__(self, **inputs):
        """Run the pipeline on one batch and return UTF-8 encoded transcriptions.

        Args:
            **inputs: a single named input tensor; assumed shape
                ``(batch, samples)`` of float32 audio (the model binding
                declares it as ``np.float32`` with per-request shape ``(-1,)``).

        Returns:
            ``{"stt_result": out}`` where ``out`` is a bytes array of shape
            ``(batch, 1)`` holding one transcription per input row.

        Raises:
            Exception: any failure is logged with its traceback and then
                re-raised, so the request fails with the real error.
        """
        try:
            (audio_array,) = inputs.values()
            batch_size: int = audio_array.shape[0]
            # Pass each row as a float32 ndarray. Going through ``tolist()``
            # boxes every sample as a Python float (slow for long clips), and
            # keeping float32 avoids losing waveform precision to a float16
            # downcast before feature extraction.
            waveforms = list(np.asarray(audio_array, dtype=np.float32))

            inference_start = time.perf_counter()  # monotonic, unlike datetime.now()
            stt_result = self._pipeline(waveforms)
            inference_latency = time.perf_counter() - inference_start
            logger.debug("Transcribed batch of %d in %.3f s", batch_size, inference_latency)

            # One single-element row per input -> shape (batch, 1), bytes dtype.
            output = np.array([[result["text"].encode("utf-8")] for result in stt_result])
        except Exception as e:
            error_traceback = traceback.format_exc()
            error_message: str = f"Exception: {repr(e)}\nTraceback: {error_traceback}"
            logger.error(f"SERVER ERROR : {error_message}")
            # Re-raise: otherwise `output` would be unbound at the return below.
            raise

        return {"stt_result": output}

def multi_device_factory(devices: List[str], num_workers: int = NUM_WORKER):
    """Build the list of inference callables to bind as Triton model instances.

    Exactly one model and one ASR pipeline are loaded per entry in
    ``devices``; that single pipeline is then shared by ``num_workers``
    ``_InferFuncWrapper`` callables. Model memory therefore scales with the
    number of devices, not with ``num_workers`` -- extra workers only add more
    callers of the same model on that device.

    NOTE(review): Triton's startup log may list these instances as
    "(CPU device 0)"; device placement is done here by ``.to(device)`` and the
    pipeline's ``device=`` argument, presumably independent of that label --
    confirm.

    Args:
        devices: torch device strings, e.g. ``["cuda:0", "cuda:1"]``.
        num_workers: callables created per device (defaults to ``NUM_WORKER``).

    Returns:
        ``len(devices) * num_workers`` callables, grouped by device.

    Raises:
        ValueError: if ``devices`` is empty or ``num_workers`` is below 1.
    """
    if not devices:
        raise ValueError("at least one device is required")
    if num_workers < 1:
        raise ValueError(f"num_workers must be >= 1, got {num_workers}")

    infer_fns = []
    for device in devices:
        batch_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_ckpt,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            use_flash_attention_2=True,
        )
        batch_model.to(device)
        asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=batch_model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=30,
            batch_size=MAX_BATCH_SIZE,
            return_timestamps=False,
            torch_dtype=torch_dtype,
            device=device,
            generate_kwargs={"language": "ko", "num_beams": 1, "do_sample": False},
        )
        # Logged once per device: only one model copy is loaded here.
        logger.info(f"LOAD MODEL TO DEVICE : {device}")
        # All workers for this device wrap the same pipeline (same weights).
        infer_fns.extend(_InferFuncWrapper(pipeline=asr_pipeline) for _ in range(num_workers))
    return infer_fns

def main():
    """Parse CLI options, start Triton, bind the Whisper model, and serve until shutdown."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=MAX_BATCH_SIZE,
        # The HF pipeline inside each callable still chunks by MAX_BATCH_SIZE.
        help="Maximum batch size Triton's dynamic batcher may form.",
        required=False,
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Log at DEBUG level.",
    )
    parser.add_argument(
        "--http_port",
        type=int,
        default=8000,
    )
    parser.add_argument(
        "--grpc_port",
        type=int,
        default=8001,
    )
    # Must differ from --grpc_port (both used to default to 8001).
    parser.add_argument(
        "--metrics_port",
        type=int,
        default=8002,
    )
    parser.add_argument(
        "--devices",
        nargs="+",
        default=["cuda:0", "cuda:1", "cuda:2", "cuda:3"],
        help="Torch devices; one model copy is loaded on each.",
    )

    args = parser.parse_args()

    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")

    with Triton(
        config=TritonConfig(http_port=args.http_port, grpc_port=args.grpc_port, metrics_port=args.metrics_port)
    ) as triton:
        logger.info(f"Loading STT model with batch size : {args.max_batch_size}")
        triton.bind(
            model_name="Whisper",
            infer_func=multi_device_factory(args.devices),
            inputs=[
                Tensor(name="inputs", dtype=np.float32, shape=(-1,)),
            ],
            outputs=[
                Tensor(name="stt_result", dtype=bytes, shape=(-1,)),
            ],
            config=ModelConfig(
                max_batch_size=args.max_batch_size,
                batcher=DynamicBatcher(max_queue_delay_microseconds=MAX_QUEUE_DELAY_MICROSECONDS),
            ),
            strict=True,
        )
        # serve() blocks until the server shuts down, so announce readiness first.
        logger.info("Serving inference")
        triton.serve()
        logger.info("Pytriton stopped")


if __name__ == "__main__":
    main()
# client
from locust import HttpUser, constant_throughput, task


class PytritonUser(HttpUser):
    """Locust user that posts one audio clip (presumably ~1 s) to the PyTriton HTTP endpoint.

    NOTE: ``constant_throughput(1)`` lets each simulated user run its task at
    most once per second, so aggregate throughput can never exceed the number
    of spawned users. When comparing 1-GPU and 4-GPU deployments, raise the
    user count until server latency -- not this cap -- limits TPS.
    """

    wait_time = constant_throughput(1)

    @task
    def getHFFlashAttentionBatchResult(self):
        """Send a single inference request with a batch of one clip, in KServe v2 JSON form."""
        # NOTE(review): 'myModelUrl' and `audio_sample_list_1s` are placeholders
        # that are not defined in this snippet.
        url = 'myModelUrl'
        headers = {
            'Content-Type': 'application/json',
        }
        data_1s_pytriton = {
            "inputs": [
                {
                    "name": "inputs",
                    "shape": (1, len(audio_sample_list_1s)),
                    "datatype": "FP32",
                    "data": audio_sample_list_1s,
                }
            ]
        }
        self.client.post(url, headers=headers, json=data_1s_pytriton)
lionsheep24 commented 2 weeks ago

P.S.: When I set the number of model instances per GPU to 10, the models load onto the CPU instead, even though GPU memory seems sufficient. With 9 instances, GPU memory usage is ~4 GB after loading and ~13 GB during inference. With 80 GB available, I expected to be able to scale further. Why can't I exceed 9 instances per GPU, and how can I fully utilize my GPU capacity? I also don't understand why GPU memory was only 4 GB after loading, even though there are 9 models on a single GPU. If I remember correctly, with driver version 515 nvidia-smi reported memory correctly (it showed over 40 GB after loading 8 instances).

Please refer to the log below, which was produced with 10 instances.

2024-06-21 01:01:14,074 - INFO - pytriton.client.client: Patch ModelClient http
I0621 01:00:57.442950 294 pinned_memory_manager.cc:275] Pinned memory pool is created at '0x7fb756000000' with size 268435456
I0621 01:00:57.498719 294 cuda_memory_manager.cc:107] CUDA memory pool is created on device 0 with size 67108864
I0621 01:00:57.498735 294 cuda_memory_manager.cc:107] CUDA memory pool is created on device 1 with size 67108864
I0621 01:00:57.498743 294 cuda_memory_manager.cc:107] CUDA memory pool is created on device 2 with size 67108864
I0621 01:00:57.498748 294 cuda_memory_manager.cc:107] CUDA memory pool is created on device 3 with size 67108864
I0621 01:00:58.289261 294 server.cc:607]
+------------------+------+
| Repository Agent | Path |
+------------------+------+
+------------------+------+

I0621 01:00:58.289316 294 server.cc:634]
+---------+------+--------+
| Backend | Path | Config |
+---------+------+--------+
+---------+------+--------+

I0621 01:00:58.289329 294 server.cc:677]
+-------+---------+--------+
| Model | Version | Status |
+-------+---------+--------+
+-------+---------+--------+

I0621 01:00:58.470838 294 metrics.cc:877] Collecting metrics for GPU 0: NVIDIA A100-SXM4-80GB
I0621 01:00:58.470875 294 metrics.cc:877] Collecting metrics for GPU 1: NVIDIA A100-SXM4-80GB
I0621 01:00:58.470883 294 metrics.cc:877] Collecting metrics for GPU 2: NVIDIA A100-SXM4-80GB
I0621 01:00:58.470890 294 metrics.cc:877] Collecting metrics for GPU 3: NVIDIA A100-SXM4-80GB
I0621 01:00:58.485218 294 metrics.cc:770] Collecting CPU metrics
I0621 01:00:58.485396 294 tritonserver.cc:2508]
+----------------------------------+------------------------------------------+
| Option                           | Value                                    |
+----------------------------------+------------------------------------------+
| server_id                        | triton                                   |
| server_version                   | 2.43.0                                   |
| server_extensions                | classification sequence model_repository |
|                                  |  model_repository(unload_dependents) sch |
|                                  | edule_policy model_configuration system_ |
|                                  | shared_memory cuda_shared_memory binary_ |
|                                  | tensor_data parameters statistics trace  |
|                                  | logging                                  |
| model_repository_path[0]         | /root/.cache/pytriton/workspace_g0f4c2gw |
|                                  | /model-store                             |
| model_control_mode               | MODE_EXPLICIT                            |
| startup_models_0                 | *                                        |
| strict_model_config              | 0                                        |
| rate_limit                       | OFF                                      |
| pinned_memory_pool_byte_size     | 268435456                                |
| cuda_memory_pool_byte_size{0}    | 67108864                                 |
| cuda_memory_pool_byte_size{1}    | 67108864                                 |
| cuda_memory_pool_byte_size{2}    | 67108864                                 |
| cuda_memory_pool_byte_size{3}    | 67108864                                 |
| min_supported_compute_capability | 6.0                                      |
| strict_readiness                 | 1                                        |
| exit_timeout                     | 30                                       |
| cache_enabled                    | 0                                        |
+----------------------------------+------------------------------------------+

I0621 01:00:58.489452 294 grpc_server.cc:2519] Started GRPCInferenceService at 0.0.0.0:8001
I0621 01:00:58.489704 294 http_server.cc:4637] Started HTTPService at 0.0.0.0:10100
I0621 01:00:58.530939 294 http_server.cc:320] Started Metrics Service at 0.0.0.0:8002
I0621 01:01:14.079496 294 model_lifecycle.cc:469] loading: Whisper:1
I0621 01:01:15.470333 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_0 (CPU device 0)
I0621 01:01:15.470406 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_1 (CPU device 0)
I0621 01:01:15.470431 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_2 (CPU device 0)
I0621 01:01:15.470478 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_3 (CPU device 0)
I0621 01:01:15.470625 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_4 (CPU device 0)
I0621 01:01:15.470678 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_5 (CPU device 0)
I0621 01:01:15.470725 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_6 (CPU device 0)
I0621 01:01:15.470795 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_7 (CPU device 0)
I0621 01:01:15.470843 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_8 (CPU device 0)
I0621 01:01:15.470897 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_9 (CPU device 0)
I0621 01:01:15.470938 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_10 (CPU device 0)
I0621 01:01:15.471411 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_11 (CPU device 0)
I0621 01:01:15.471528 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_13 (CPU device 0)
I0621 01:01:15.471580 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_14 (CPU device 0)
I0621 01:01:15.471652 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_12 (CPU device 0)
I0621 01:01:15.471760 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_15 (CPU device 0)
I0621 01:01:15.472498 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_16 (CPU device 0)
I0621 01:01:15.473107 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_17 (CPU device 0)
I0621 01:01:15.473518 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_18 (CPU device 0)
I0621 01:01:15.474050 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_19 (CPU device 0)
I0621 01:01:15.474312 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_21 (CPU device 0)
I0621 01:01:15.474353 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_20 (CPU device 0)
I0621 01:01:15.474447 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_22 (CPU device 0)
I0621 01:01:15.476588 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_23 (CPU device 0)
I0621 01:01:15.476716 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_24 (CPU device 0)
I0621 01:01:15.476929 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_25 (CPU device 0)
I0621 01:01:15.477036 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_26 (CPU device 0)
I0621 01:01:15.477561 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_27 (CPU device 0)
I0621 01:01:15.477747 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_28 (CPU device 0)
I0621 01:01:15.478455 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_29 (CPU device 0)
I0621 01:01:15.478707 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_30 (CPU device 0)
I0621 01:01:15.479798 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_31 (CPU device 0)
I0621 01:01:15.479908 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_32 (CPU device 0)
I0621 01:01:15.488745 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_33 (CPU device 0)
I0621 01:01:15.495002 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_34 (CPU device 0)
I0621 01:01:15.499609 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_35 (CPU device 0)
I0621 01:01:15.499609 294 python_be.cc:2381] TRITONBACKEND_ModelInstanceInitialize: Whisper_0_36 (CPU device 0)
2024-06-21 01:01:16,524 - INFO - pytriton: Serving inference
jaehyeong-bespin commented 1 week ago

I have the same issue. Have you resolved it?