stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy-docs.vercel.app/
MIT License

Triton Inference Server #1095

Open Anindyadeep opened 3 months ago

Anindyadeep commented 3 months ago

When it comes to fully production-grade inference servers, Triton Inference Server (TIS) is highly optimized and open source. Integrating it into DSPy, along with TensorRT-LLM (#1094), would be a great addition.

this-josh commented 1 week ago

It needs to be tidied up for a PR, but here is a working example of using Triton with DSPy:

import threading

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

import dspy
from dsp.modules.lm import LM

class TritonLMClient(LM):
    def __init__(self, url, model, **kwargs):
        super().__init__(model, **kwargs)
        self.client = grpcclient.InferenceServerClient(url=url)
        self.model = model
        # Merge caller-supplied kwargs over the defaults instead of discarding them
        self.kwargs = {
            "temperature": 0.7,  # Default temperature
            "max_tokens": 2048,  # Default max tokens
            **kwargs,
        }
        self.provider = 'tensorrt_llm'

    def generate(self, prompt, **kwargs):
        # Update kwargs with any provided parameters
        self.kwargs.update(kwargs)
        result_event = threading.Event()
        output_result = None
        error_result = None

        # Prepare the inputs; Triton expects 2D arrays of shape [batch_size, 1]
        input_data = np.array([[prompt.encode()]], dtype=object)
        max_tokens = np.array([[self.kwargs['max_tokens']]], dtype=np.int32)
        temperature = np.array([[self.kwargs['temperature']]], dtype=np.float32)
        inputs = [
            grpcclient.InferInput("text_input", input_data.shape, np_to_triton_dtype(input_data.dtype)),
            grpcclient.InferInput("max_tokens", max_tokens.shape, np_to_triton_dtype(max_tokens.dtype)),
            grpcclient.InferInput("temperature", temperature.shape, np_to_triton_dtype(temperature.dtype)),
        ]
        inputs[0].set_data_from_numpy(input_data)
        inputs[1].set_data_from_numpy(max_tokens)
        inputs[2].set_data_from_numpy(temperature)
        # Prepare the output
        outputs = [
            grpcclient.InferRequestedOutput("text_output")
        ]

        def callback_function(result, error):
            nonlocal output_result, error_result

            if error:
                # Raising here would be swallowed by the client's callback thread,
                # so record the error and signal the main thread instead
                error_result = error
                result_event.set()
                return

            output = result.as_numpy("text_output")
            if output.dtype.type is np.bytes_:
                output_text = output.item().decode('utf-8')
            else:
                output_text = str(output)
            output_result = output_text
            result_event.set()

        self.client.start_stream(callback=callback_function)
        self.client.async_stream_infer(model_name=self.model, inputs=inputs, outputs=outputs)

        # Wait for the callback to fire; output_result stays None if the call times out
        result_event.wait(timeout=10)
        self.client.stop_stream()

        if error_result:
            raise Exception(f"Received error: {error_result}")
        return output_result

    def __call__(self, prompt, **kwargs):
        return self.request(prompt, **kwargs)

    def basic_request(self, prompt, **kwargs):
        response = self.generate(prompt, **kwargs)
        history = {
            "prompt": prompt,
            "response": response,
        }
        self.history.append(history)
        return response

lm = TritonLMClient("172.22.93.97:5556", "tensorrt_llm_bls")
dspy.settings.configure(lm=lm)
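
Once configured, it works like any other DSPy LM. As a quick usage check (the signature and question below are just an illustration, not part of the Triton setup):

predictor = dspy.Predict("question -> answer")
result = predictor(question="What is Triton Inference Server?")
print(result.answer)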

I've no doubt this can be massively improved, but I hope it can serve as a useful starting point.

Note that this only supports single-threaded use, so you need evaluate = Evaluate(num_threads=1, ...). This is obviously a limitation of the current implementation, but as long as your GPU is saturated, it's fine.
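
For reference, here is a minimal sketch of an evaluation run wired for a single thread; the dev set, metric, and program below are placeholders for illustration only:

from dspy.evaluate import Evaluate

# Placeholder dev set and metric, purely for illustration
devset = [dspy.Example(question="What is Triton?", answer="an inference server").with_inputs("question")]
metric = lambda example, pred, trace=None: example.answer.lower() in pred.answer.lower()

qa = dspy.Predict("question -> answer")
evaluate = Evaluate(devset=devset, metric=metric, num_threads=1, display_progress=True)
evaluate(qa)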

Sorry, I keep editing this as I gradually make it more functional.