BoundaryML / baml

BAML is a language that helps you get structured data from LLMs, with the best DX possible. Check out the promptfiddle.com playground
https://docs.boundaryml.com
Apache License 2.0
1.01k stars, 26 forks

Add batch prediction #705

Open 1feres1 opened 2 months ago

1feres1 commented 2 months ago

It would be great if BAML worked with batch inference tools like vLLM, or added its own batching support.

hellovai commented 2 months ago

That's a great idea. Can you share an example use case for how you'd like to do batching via BAML specifically?

If you're doing batching: since BAML generates async clients for TypeScript and Python, you could take any BAML function and run the calls concurrently with asyncio:

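import asyncio
from typing import List
from baml_client import b

async def batch_call(params: List[str]):
    # Start one BAML call per input and wait for all of them;
    # gather takes awaitables as separate arguments, so unpack the list
    return await asyncio.gather(*[b.MyFunction(i) for i in params])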

That said, we don't yet do anything more clever, like running at most N calls in parallel at a time, but I think that's very much a possibility for us.

Or did you mean something else?

1feres1 commented 2 months ago

My example use case is summarizing chat data (a large dataset).

But I didn't find a way to get both very fast inference and correct JSON formatting working together, and both are very important.

hellovai commented 2 months ago

It looks like vLLM supports the OpenAI API spec! Which means you should be able to use it with BAML.

Can you share how you do batch inference with vLLM? Perhaps there's an easy way to do this. (Ideally code / curl examples!)

1feres1 commented 2 months ago

Hello, hope you're doing well.

Below is an example of how I use vLLM for batch predictions for summary generation:

from huggingface_hub import login
from vllm import LLM, SamplingParams

login(token="hf")  # replace this with your Hugging Face token

sampling_params = SamplingParams(max_tokens=300)
llm = LLM("mistralai/Mistral-7B-Instruct-v0.2", gpu_memory_utilization=1)

prompts = [
    "summarize this : hello world",
    "summarize this : agent set password for customer",
]
# generate() takes the whole list of prompts and batches them internally
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    summary = output.outputs[0].text
    print(summary)
    print('------------------------')