triton-inference-server / pytriton

PyTriton is a Flask/FastAPI-like interface that simplifies Triton's deployment in Python environments.
https://triton-inference-server.github.io/pytriton/
Apache License 2.0

How to pass priority level during inference? #42

Open jackielam918 opened 10 months ago

jackielam918 commented 10 months ago

I am confused about how to pass the priority value when performing inference.

For example, if I set up the DynamicBatcher with two priority levels:

batcher = DynamicBatcher(
    max_queue_delay_microseconds=1000,
    preserve_ordering=False,
    priority_levels=2,
    default_priority_level=2,
)

When using the client and calling infer_batch or infer_sample, where is the priority supposed to be passed? Looking at the docs, I assumed headers is where you would pass it, so I tried this:

client.infer_batch(..., headers={'priority': 1})

However, that does not work. I couldn't find any examples or more detailed docs anywhere on how priority is supposed to be used. Any help would be appreciated.

piotrm-nvidia commented 10 months ago

Thank you for your question.

It seems you're trying to set the priority of inference requests when using a DynamicBatcher in Triton Inference Server. From the code you posted, you are attempting to set the priority via the headers argument, but the client API doesn't use headers for setting priorities.

To properly utilize priority levels in Triton, we'd need to add this functionality into the API in a different way. One possible approach would be to extend the client by creating a subclass that handles priority specifically. Here's an example of how we could create a new class named PriorityModelClient that extends ModelClient:

from pytriton.client import ModelClient

class PriorityModelClient(ModelClient):
    def __init__(self, url, model_name, priority=None):
        super().__init__(url, model_name)
        # Priority level attached to every request made through this client.
        self._priority = priority

    def _get_infer_extra_args(self):
        # Extend the base extra args with the priority level, if one was set.
        extra_args = super()._get_infer_extra_args()
        if self._priority is not None:
            extra_args["priority"] = self._priority
        return extra_args

This custom class introduces an internal variable to store the priority level and overrides the _get_infer_extra_args method to add the priority to the extra arguments sent with the inference request.
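If this works the way I expect, these extra args end up as keyword arguments to the underlying tritonclient infer() call, which accepts priority directly. For reference, here is a minimal sketch of the equivalent request made with raw tritonclient (assuming the server's gRPC endpoint is on the default port 8001):

import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001")
data = np.char.encode([["test"]], "utf-8")

# Describe the single BYTES input of the "Test" model and attach the data.
inp = grpcclient.InferInput("text", list(data.shape), "BYTES")
inp.set_data_from_numpy(data)

# priority is a first-class argument of tritonclient's infer().
result = client.infer("Test", inputs=[inp], priority=1)
print(result.as_numpy("text"))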

For usage, you'd instantiate the PriorityModelClient with the priority level and then use its infer_batch method as usual:

import numpy as np

pcl = PriorityModelClient("grpc://localhost", "Test", priority=1)
pcl.infer_batch(np.char.encode([["test"]], "utf-8"))

Could you provide more detail on your usage scenario and how you intend to use these priority levels? Understanding the context of your application and its requirements would let us suggest a more tailored solution.

Furthermore, in terms of API extension, we aim for a straightforward and unobtrusive integration of priority handling. If the current approach is inconvenient or insufficient for your use case, we'd love to gather feedback on desired features or improvements; that information can guide enhancements that better support real-world scenarios.

I have also written test code that starts a server for this client example. You can see it below:

import numpy as np

from pytriton.decorators import batch
from pytriton.model_config import ModelConfig, Tensor, DynamicBatcher
from pytriton.triton import Triton

@batch
def _infer_fn(text):
    # Echo the input back unchanged.
    return {"text": text}

# Two priority levels; requests that don't set one are queued at level 2.
batcher = DynamicBatcher(
    max_queue_delay_microseconds=1000,
    preserve_ordering=False,
    priority_levels=2,
    default_priority_level=2,
)

triton = Triton()
triton.bind(
    model_name="Test",
    infer_func=_infer_fn,
    inputs=[
        Tensor(name="text", dtype=bytes, shape=(-1,)),
    ],
    outputs=[
        Tensor(name="text", dtype=bytes, shape=(-1,)),
    ],
    config=ModelConfig(decoupled=False, max_batch_size=2, batcher=batcher),
)

# Start the server; use triton.serve() instead to block until interrupted.
triton.run()
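With the server above running, a quick way to exercise both priority levels is to run two clients side by side. Here is a minimal sketch, assuming the PriorityModelClient class from earlier is in scope:

import numpy as np

# Level 1 is the higher priority; level 2 matches the batcher's default.
high = PriorityModelClient("grpc://localhost", "Test", priority=1)
low = PriorityModelClient("grpc://localhost", "Test", priority=2)

data = np.char.encode([["test"]], "utf-8")
print(high.infer_batch(data))
print(low.infer_batch(data))

high.close()
low.close()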

I hope this helps you understand how to use the priority feature of the DynamicBatcher. If you have any further questions, please let me know.

jackielam918 commented 10 months ago

Thanks for your response. The solution above looks like it should work for my use case. I will try it and get back to you.

As a follow-up, is there a reason why priority is not part of the constructor? To me, it feels like a similar type of parameter to inference_timeout_s, which is part of the constructor. Alternatively, the infer_* methods could take another parameter for supplying extra_args, similar to headers and parameters.
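For illustration, here is a purely hypothetical sketch of both shapes (neither the priority constructor argument nor an extra_args parameter exists in pytriton today; they're made up here):

import numpy as np
from pytriton.client import ModelClient

# Hypothetical option 1: priority accepted by the constructor,
# next to inference_timeout_s (priority= is not a real parameter):
client = ModelClient("grpc://localhost", "Test",
                     inference_timeout_s=60.0, priority=1)

# Hypothetical option 2: an extra_args parameter on the infer_* methods,
# alongside headers and parameters (extra_args= is also not real):
client.infer_batch(np.char.encode([["test"]], "utf-8"),
                   extra_args={"priority": 1})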

piotrm-nvidia commented 10 months ago

I will consider your request as a feature proposal.