neuralmagic / deepsparse

Sparsity-aware deep learning inference runtime for CPUs
https://neuralmagic.com/deepsparse/

[BugFix] Error with streaming in Chat Pipeline #1283

Closed rahul-tuli closed 1 year ago

rahul-tuli commented 1 year ago

This PR fixes a bug in ChatPipeline where calling the pipeline with streaming enabled raised an AttributeError. With this fix, streaming returns a generator of outputs as expected.

Reproduction Code:

import inspect

from deepsparse import Pipeline

streaming = True
input_text = "def fib("

# Build a chat pipeline from a locally cached model deployment directory
chat_pipeline = Pipeline.create(
    task="chat",
    model_path="/home/rahul/.cache/sparsezoo/neuralmagic/codegen_mono-350m-bigpython_bigquery_thepile-base/deployment",
)
# With streaming=True, the call should return a generator of partial outputs
predictions = chat_pipeline(sequences=[input_text], streaming=streaming)

assert inspect.isgenerator(predictions), "Predictions should be a generator"

# Consume the stream one output at a time
for i, prediction in enumerate(predictions):
    print(f"Prediction {i}: {prediction}")

Output:

Before this PR:

  File "/home/rahul/projects/deepsparse/src/deepsparse/pipeline.py", line 284, in __call__
    pipeline_outputs = self.process_engine_outputs(engine_outputs, **context)
  File "/home/rahul/projects/deepsparse/src/deepsparse/transformers/pipelines/chat.py", line 149, in process_engine_outputs
    return ChatOutput(**text_generation_output.dict(), session_ids=session_ids)
AttributeError: 'generator' object has no attribute 'dict'
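
The traceback above suggests that, with streaming enabled, the text generation step hands back a generator rather than a single output object, so calling `.dict()` on it fails. The following is a minimal sketch of that failure mode and of one way to handle both cases, not the actual change made in this PR. `TextGenOutput`, `ChatOutputSketch`, `stream_tokens`, and `to_chat_output` are hypothetical stand-ins built on plain dataclasses, and none of them are DeepSparse APIs.

```python
import inspect
from dataclasses import asdict, dataclass, field
from typing import List


@dataclass
class TextGenOutput:
    """Hypothetical stand-in for a text generation output schema."""
    prompts: List[str]
    generations: List[str]

    def dict(self) -> dict:
        # Mimics a schema object that can be dumped to a plain dict
        return asdict(self)


@dataclass
class ChatOutputSketch(TextGenOutput):
    """Hypothetical stand-in for a chat output: same fields plus session_ids."""
    session_ids: List[str] = field(default_factory=list)


def stream_tokens(prompt, tokens):
    """Yield one output per token, the way a streaming pipeline might."""
    for token in tokens:
        yield TextGenOutput(prompts=[prompt], generations=[token])


def to_chat_output(output, session_ids):
    """Attach session_ids to a single output or to every streamed output."""
    if inspect.isgenerator(output):
        # Streaming case: wrap each chunk lazily instead of calling .dict()
        # on the generator itself
        return (
            ChatOutputSketch(**chunk.dict(), session_ids=session_ids)
            for chunk in output
        )
    # Non-streaming case: a single output object
    return ChatOutputSketch(**output.dict(), session_ids=session_ids)


# Reproduce the failure: a generator object has no .dict() method
try:
    stream_tokens("def fib(", ["n"]).dict()
except AttributeError as err:
    failure = str(err)

# The generator-aware path keeps the result lazy and adds session_ids per chunk
streamed = to_chat_output(stream_tokens("def fib(", ["n", ")"]), ["session-1"])
is_lazy = inspect.isgenerator(streamed)
chunks = list(streamed)
```

Checking with `inspect.isgenerator` keeps the non-streaming path unchanged while leaving the streaming path lazy, so chunks are only produced as the caller iterates.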

After this PR:

Prediction 515: created=datetime.datetime(2023, 9, 26, 11, 27, 50, 625060) prompts=['def fib('] generations=[GeneratedText(text='10', score=None, finished=False, finished_reason=None)] session_ids=['540e329d-8f58-4875-9526-29dc1af9467b']
Prediction 516: created=datetime.datetime(2023, 9, 26, 11, 27, 50, 625120) prompts=['def fib('] generations=[GeneratedText(text='(', score=None, finished=False, finished_reason=None)] session_ids=['540e329d-8f58-4875-9526-29dc1af9467b']
Prediction 517: created=datetime.datetime(2023, 9, 26, 11, 27, 50, 625180) prompts=['def fib('] generations=[GeneratedText(text='n', score=None, finished=False, finished_reason=None)] session_ids=['540e329d-8f58-4875-9526-29dc1af9467b']
Prediction 518: created=datetime.datetime(2023, 9, 26, 11, 27, 50, 625240) prompts=['def fib('] generations=[GeneratedText(text='-', score=None, finished=False, finished_reason=None)] session_ids=['540e329d-8f58-4875-9526-29dc1af9467b']
Prediction 519: created=datetime.datetime(2023, 9, 26, 11, 27, 50, 625300) prompts=['def fib('] generations=[GeneratedText(text='2', score=None, finished=False, finished_reason=None)] session_ids=['540e329d-8f58-4875-9526-29dc1af9467b']
Prediction 520: created=datetime.datetime(2023, 9, 26, 11, 27, 50, 625365) prompts=['def fib('] generations=[GeneratedText(text=')', score=None, finished=False, finished_reason=None)] session_ids=['540e329d-8f58-4875-9526-29dc1af9467b']
Prediction 521: created=datetime.datetime(2023, 9, 26, 11, 27, 50, 625425) prompts=['def fib('] generations=[GeneratedText(text='\n', score=None, finished=False, finished_reason=None)] session_ids=['540e329d-8f58-4875-9526-29dc1af9467b']
Prediction 522: created=datetime.datetime(2023, 9, 26, 11, 27, 50, 625485) prompts=['def fib('] generations=[GeneratedText(text='\n', score=None, finished=False, finished_reason=None)] session_ids=['540e329d-8f58-4875-9526-29dc1af9467b']
Prediction 523: created=datetime.datetime(2023, 9, 26, 11, 27, 50, 625546) prompts=['def fib('] generations=[GeneratedText(text='def', score=None, finished=False, finished
...

Also added an automated test!
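
The test added in this PR is not shown here. As a rough sketch of what such a check might assert, the example below mirrors the reproduction code but runs against a stub, so it needs no model. `stub_chat_pipeline` and both test functions are hypothetical names, and the dict-based outputs are an assumption made to keep the example self-contained.

```python
import inspect


def stub_chat_pipeline(sequences, streaming=False):
    """Hypothetical stub: a generator of chunks when streaming, else one output."""
    tokens = ["n", ")", ":"]

    def _stream():
        for token in tokens:
            yield {
                "prompts": sequences,
                "generations": [token],
                "session_ids": ["stub-session"],
            }

    if streaming:
        return _stream()
    return {
        "prompts": sequences,
        "generations": ["".join(tokens)],
        "session_ids": ["stub-session"],
    }


def test_streaming_returns_generator():
    predictions = stub_chat_pipeline(sequences=["def fib("], streaming=True)
    assert inspect.isgenerator(predictions), "Predictions should be a generator"
    chunks = list(predictions)
    assert len(chunks) > 0
    # Every streamed chunk should carry the session ids
    assert all(chunk["session_ids"] == ["stub-session"] for chunk in chunks)


def test_non_streaming_returns_single_output():
    prediction = stub_chat_pipeline(sequences=["def fib("], streaming=False)
    assert not inspect.isgenerator(prediction)
    assert prediction["session_ids"] == ["stub-session"]
```

Against the real pipeline, the same assertions would apply to the object returned by `Pipeline.create(task="chat", ...)` when called with `streaming=True`.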