It needs to be tidied up for a PR, but here is a working example of using Triton Inference Server with dspy:
```python
import threading

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

import dspy
from dsp.modules.lm import LM


class TritonLMClient(LM):
    def __init__(self, url, model, **kwargs):
        super().__init__(model)
        self.client = grpcclient.InferenceServerClient(url=url)
        self.model = model
        self.kwargs = {
            "temperature": 0.7,  # default temperature
            "max_tokens": 2048,  # default max tokens
            **kwargs,
        }
        self.provider = "tensorrt_llm"

    def generate(self, prompt, **kwargs):
        # Per-call parameters override the defaults without mutating them
        kwargs = {**self.kwargs, **kwargs}
        result_event = threading.Event()
        output_result = None
        callback_error = None

        # Triton expects these inputs as 2D arrays
        input_data = np.array([[prompt.encode()]], dtype=object)
        max_tokens = np.array([[kwargs["max_tokens"]]], dtype=np.int32)
        temperature = np.array([[kwargs["temperature"]]], dtype=np.float32)

        inputs = [
            grpcclient.InferInput("text_input", input_data.shape, np_to_triton_dtype(input_data.dtype)),
            grpcclient.InferInput("max_tokens", max_tokens.shape, np_to_triton_dtype(max_tokens.dtype)),
            grpcclient.InferInput("temperature", temperature.shape, np_to_triton_dtype(temperature.dtype)),
        ]
        inputs[0].set_data_from_numpy(input_data)
        inputs[1].set_data_from_numpy(max_tokens)
        inputs[2].set_data_from_numpy(temperature)

        outputs = [grpcclient.InferRequestedOutput("text_output")]

        def callback_function(result, error):
            # This runs on Triton's callback thread: raising here would be
            # swallowed and leave the caller waiting on the event, so record
            # the outcome and signal the event instead.
            nonlocal output_result, callback_error
            if error:
                callback_error = error
            else:
                output = result.as_numpy("text_output")
                if output.dtype.type is np.bytes_:
                    output_result = output.item().decode("utf-8")
                else:
                    output_result = str(output)
            result_event.set()

        self.client.start_stream(callback=callback_function)
        self.client.async_stream_infer(model_name=self.model, inputs=inputs, outputs=outputs)
        result_event.wait(timeout=10)
        self.client.stop_stream()

        if callback_error is not None:
            raise RuntimeError(f"Received error: {callback_error}")
        return output_result

    def basic_request(self, prompt, **kwargs):
        response = self.generate(prompt, **kwargs)
        self.history.append({"prompt": prompt, "response": response})
        return response

    def __call__(self, prompt, **kwargs):
        # dspy expects a list of completions
        return [self.request(prompt, **kwargs)]


lm = TritonLMClient("172.22.93.97:5556", "tensorrt_llm_bls")
dspy.settings.configure(lm=lm)
```
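For a quick smoke test of the configured client, something like the sketch below should work, assuming the server at 172.22.93.97:5556 is actually serving `tensorrt_llm_bls`; the signature and question are arbitrary examples, not part of the original snippet.

```python
import dspy

# Hypothetical one-liner program for illustration; any dspy module
# will route its calls through the TritonLMClient configured above.
qa = dspy.Predict("question -> answer")
prediction = qa(question="What is the capital of France?")
print(prediction.answer)
```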
I've no doubt this can be massively improved, but I hope it can serve as a useful starting point.
Note that this only supports a single thread, so you need `evaluate = Evaluate(num_threads=1, ...)`, as in the sketch below. This is obviously a limitation of the current implementation, but as long as your GPU stays saturated it's fine enough.
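Here is a minimal sketch of that evaluation setup; `devset`, `my_metric`, and `program` are placeholders for your own data, metric, and dspy module:

```python
from dspy.evaluate import Evaluate

# num_threads=1 is required because the client above serializes
# requests through a single gRPC stream.
evaluate = Evaluate(
    devset=devset,        # placeholder: your evaluation examples
    metric=my_metric,     # placeholder: your scoring function
    num_threads=1,
    display_progress=True,
)
score = evaluate(program)  # placeholder: your dspy program
```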
Sorry, I keep editing this as I gradually make it more functional.
When it comes to fully production-grade inference servers, Triton Inference Server (TIS) is heavily optimized and open source, so integrating it into dspy alongside TensorRT-LLM (#1094) would be a great addition.