neuralmagic / deepsparse

Sparsity-aware deep learning inference runtime for CPUs
https://neuralmagic.com/deepsparse/

[server][Pipeline] fix async future being used by continuous_batching #1535

Closed · dsikka closed 6 months ago

dsikka commented 6 months ago

Summary

Fixes the async future being used by continuous_batching (the server's async pathway).

Inspiration: https://github.com/python/cpython/blob/b331381485c1965d1c88b7aee7ae9604aca05758/Lib/asyncio/base_events.py#L871
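The linked CPython code is asyncio's executor bridge (`run_in_executor`), which submits blocking work to a thread pool and wraps the resulting concurrent.futures.Future into an awaitable asyncio future. A minimal sketch of that pattern (names here are illustrative, not the deepsparse implementation):

import asyncio
import concurrent.futures

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

def blocking_inference(prompt):
    # stand-in for a blocking continuous-batching engine call
    return f"completed: {prompt}"

async def submit_and_await(prompt):
    # wrap the thread-pool future so it can be awaited on the event loop,
    # mirroring what asyncio.BaseEventLoop.run_in_executor does internally
    return await asyncio.wrap_future(executor.submit(blocking_inference, prompt))

print(asyncio.run(submit_and_await("Mario jumped")))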

Testing

Server (using async pathway)

num_workers: 1
endpoints:
  - task: text_generation
    model: "hf:neuralmagic/MiniChat-1.5-3B-pruned50-quant-ds"
    kwargs:
      continuous_batch_sizes: [2, 3, 4, 5, 6, 7, 8, 16]
      force_max_tokens: true

Client


import requests
from threading import Thread
import time
import argparse

parser = argparse.ArgumentParser(description='Benchmark text_generation throughput against the server.')
parser.add_argument('--num-threads', type=int, default=1)
parser.add_argument('--num-tokens', type=int, default=25)

_STREAM = False
_PRINT = False

def main(num_threads=1, num_tokens=25):
    url = "http://localhost:5543/v2/models/text_generation-0/infer"

    def run(idx, prompt="Mario jumped"):
        print(f"launching thread {idx}")

        num_return_sequences = 1
        start = time.perf_counter()
        obj = {
            "prompt": prompt,
            "generation_kwargs": {
                "max_length": num_tokens
            },
            "num_return_sequences": num_return_sequences
        }

        response = requests.post(url, json=obj)

        if _STREAM:
            # streaming over HTTP would also want requests.post(..., stream=True);
            # left inert here since _STREAM is False by default
            for line in response.iter_lines():
                print(line)
        else:
            if _PRINT:
                print(response.text)

        end = time.perf_counter()
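        # approximate aggregate tokens/sec across all threads,
        # assuming the threads run concurrently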
        print(f"finished thread {idx} : {(num_tokens * num_threads * num_return_sequences) / (end - start): 0.5f}")

    ts = [Thread(target=run, args=[idx, "Mario jumped"]) for idx in range(num_threads)]

    for t in ts:
        t.start()
    for t in ts:
        t.join()

if __name__ == "__main__":
    args = parser.parse_args()
    main(num_threads=args.num_threads, num_tokens=args.num_tokens)
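With the server running, the client can be driven at the benchmarked settings, e.g. (script name hypothetical):

python client.py --num-threads 16 --num-tokens 500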

Server Numbers:

Threads | max_tokens | tokens/second
--------|------------|--------------
4       | 500        | 100
8       | 500        | 155
16      | 500        | 210

Default pipeline.run(...)/non-async pathway

from deepsparse import Pipeline
from threading import Thread
import time
import argparse

parser = argparse.ArgumentParser(description='Benchmark text_generation throughput via the pipeline directly.')
parser.add_argument('--num-threads', type=int, default=1)
parser.add_argument('--num-tokens', type=int, default=25)

_STREAM = False
_PRINT = False

def main(num_threads=1, num_tokens=25):

    pipeline = Pipeline.create(
        task="text_generation",
        model_path="hf:neuralmagic/MiniChat-1.5-3B-pruned50-quant-ds",
        continuous_batch_sizes=[2,3,4,5,6,7,8,16],
        force_max_tokens=True
    )

    def run(idx, prompt="Mario jumped"):
        print(f"launching thread {idx}")

        num_return_sequences = 1
        start = time.perf_counter()
        completion = pipeline(
            prompt=prompt,
            max_length=num_tokens,
            num_return_sequences=num_return_sequences
        )

        if _STREAM:
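            # iterating the output requires the pipeline to be invoked with
            # streaming enabled (a streaming kwarg is assumed); _STREAM defaults to False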
            for c in completion:
                print(c)
        else:
            if _PRINT:
                print(completion)

        end = time.perf_counter()
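        # same aggregate tokens/sec approximation as in the server client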
        print(f"finished thread {idx} : {(num_tokens * num_threads * num_return_sequences) / (end - start): 0.5f}")

    ts = [Thread(target=run, args=[idx, "Mario jumped"]) for idx in range(num_threads)]

    for t in ts:
        t.start()
    for t in ts:
        t.join()

if __name__ == "__main__":
    args = parser.parse_args()
    main(num_threads=args.num_threads, num_tokens=args.num_tokens)

Non-async Numbers:

Threads | max_tokens | tokens/second
--------|------------|--------------
4       | 500        | 97
8       | 500        | 148
16      | 500        | 207
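For a rough comparison, the async pathway's gain over the non-async pathway at each thread count can be computed from the two tables:

# throughput from the tables above: async server pathway vs. direct pipeline
async_tps = {4: 100, 8: 155, 16: 210}
sync_tps = {4: 97, 8: 148, 16: 207}
for threads, tps in async_tps.items():
    gain = (tps / sync_tps[threads] - 1) * 100
    print(f"{threads} threads: +{gain:.1f}%")  # roughly +3.1%, +4.7%, +1.4%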