neuralmagic / deepsparse

Sparsity-aware deep learning inference runtime for CPUs
https://neuralmagic.com/deepsparse/
Other
2.94k stars 169 forks source link

[TextGeneration] Use appropriate future within `continuous_batching_scheduler` #1506

Closed — dsikka closed this pull request 6 months ago.

dsikka commented 6 months ago

Summary

Testing

Launch the DeepSparse server with the following config:


# DeepSparse server config: one text-generation endpoint with continuous
# batching enabled for batch sizes 4, 8 and 16.
num_workers: 1
num_streams: 1
endpoints:
  - task: text_generation
    model: "hf:mgoin/TinyStories-1M-ds"
    kwargs:
      continuous_batch_sizes: [4, 8, 16]
      force_max_tokens: true

Running the following client script with several concurrent threads shows batch sizes greater than 1 being used:


from openai import OpenAI
from threading import Thread
import time, argparse

import argparse

parser = argparse.ArgumentParser(
    description=(
        "Send concurrent completion requests to a local OpenAI-compatible "
        "server to exercise continuous batching."
    )
)
parser.add_argument(
    "--num-threads",
    type=int,
    default=1,
    help="number of concurrent requests, one per thread (default: %(default)s)",
)
parser.add_argument(
    "--num-tokens",
    type=int,
    default=25,
    help="max_tokens requested per completion (default: %(default)s)",
)

# Module-level switches read by main(): _STREAM requests a streamed response
# and prints each chunk; _PRINT prints the completion text when not streaming.
_STREAM = False
_PRINT = True

def main(num_threads=1, num_tokens=25, prompt="Mario jumped"):
    """Send ``num_threads`` concurrent completion requests to a local server.

    Each request runs in its own thread so the requests reach the server at
    the same time, giving it the opportunity to batch them together.

    Args:
        num_threads: number of concurrent requests (one per thread).
        num_tokens: ``max_tokens`` requested for each completion.
        prompt: prompt text sent with every request.
    """
    openai_api_key = "EMPTY"
    openai_api_base = "http://localhost:5543/v1"

    client = OpenAI(
        # defaults to os.environ.get("OPENAI_API_KEY")
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    # NOTE(review): the OpenAI SDK normally exposes the model name as
    # ``models.data[0].id``; this indexes the first entry positionally,
    # presumably to match how this server's response is parsed — confirm.
    model = models.data[0][1]

    def run(idx, prompt=prompt):
        print(f"launching thread {idx}")

        start = time.perf_counter()
        completion = client.completions.create(
            model=model,
            prompt=prompt,
            n=1,
            max_tokens=num_tokens,
            stream=_STREAM,
        )

        if _STREAM:
            for c in completion:
                print(c)
        elif _PRINT:
            print(completion.choices[0].text)

        elapsed = time.perf_counter() - start
        # This thread requested only ``num_tokens`` tokens, so its own rate
        # is num_tokens / elapsed; aggregate throughput is reported below.
        print(f"finished thread {idx} : {num_tokens / elapsed: 0.5f}")

    ts = [Thread(target=run, args=[idx, prompt]) for idx in range(num_threads)]

    wall_start = time.perf_counter()
    for t in ts:
        t.start()
    for t in ts:
        t.join()
    wall = time.perf_counter() - wall_start

    if ts:
        # Counts requested tokens; exact only if the server always generates
        # max_tokens (e.g. force_max_tokens), otherwise an upper bound.
        total_tokens = num_tokens * num_threads
        print(f"aggregate : {total_tokens / wall: 0.5f} tokens/s in {wall:0.3f}s")

if __name__ == "__main__":
    # Parse the CLI flags and hand them to main() positionally.
    cli = parser.parse_args()
    main(cli.num_threads, cli.num_tokens)
robertgshaw2-neuralmagic commented 6 months ago

nice work!