neuralmagic / deepsparse

Sparsity-aware deep learning inference runtime for CPUs
https://neuralmagic.com/deepsparse/
Other
2.97k stars 171 forks

[Text Generation][V2] NonKVCachePipeline #1483

Closed dbogunowicz closed 8 months ago

dbogunowicz commented 8 months ago

Feature Description

Added the TextGenerationPipelineNoCache. This pipeline processes the prompt and returns a single new token; it does not use a KV cache and does not generate any further tokens. Its main purpose is to map prompt tokens to logits, which is instrumental for computing the perplexity of a model on a given dataset.
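
To make the perplexity use case concrete, here is a minimal NumPy sketch of how per-token logits could be turned into a perplexity value. It is not part of this PR: the helper names `log_softmax` and `perplexity` are hypothetical, and the sketch assumes that row i of the logits scores the token at position i+1 of the prompt, which should be checked against the actual pipeline output.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the vocabulary axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def perplexity(logits, token_ids):
    """Perplexity of one sequence.

    logits:    array of shape (seq_len, vocab_size); row i is ASSUMED to
               score the token at position i + 1 (hypothetical layout).
    token_ids: the seq_len prompt token ids.
    """
    log_probs = log_softmax(np.asarray(logits, dtype=np.float64))
    targets = np.asarray(token_ids)[1:]  # the first token has no prediction
    # Log-probability the model assigned to each actual next token.
    picked = log_probs[np.arange(len(targets)), targets]
    return float(np.exp(-picked.mean()))


# Uniform logits over a 5-token vocabulary give a perplexity of about 5.
print(perplexity(np.zeros((4, 5)), [0, 1, 2, 3]))
```

For a whole dataset, the per-token negative log-likelihoods would typically be accumulated across all sequences before exponentiating, rather than averaging per-sequence perplexities.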

Testing

Updated the integration tests to cover the case of non-kv-cache inference.

Example Use

from deepsparse.v2.text_generation import TextGenerationPipelineNoCache

prompt = ["Some funny prompt", "Why are you so"]

pipeline = TextGenerationPipelineNoCache(model_path="hf:mgoin/TinyStories-1M-ds",
                                         onnx_model_name="model-orig.onnx",
                                         sequence_length=20)

out = pipeline(prompt=prompt,
               include_prompt_logits=True,
               generation_kwargs=dict(output_scores=True))

for gen in out.generations:
    print(gen)
# Output:
text='.' score=array([[ 2.9344807 , -0.03345669, -4.11256   , ..., -6.9316325 ,
        -4.6005425 ,  1.1827914 ],
       [ 7.008805  , -0.11603884, -7.1837015 , ..., -7.0405912 ,
        -2.386351  , -2.2007818 ],
       [ 6.348213  , -2.2960157 , -6.433192  , ..., -6.5930486 ,
        -5.8315077 , -0.58804405],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ,  0.        ]], dtype=float32) finished=True finished_reason='length'
# Note: the logits are zero-padded at the end, because all score arrays must have the same shape (the length of the longest prompt in the input, plus one).
text=' sad' score=array([[ 2.560934 ,  1.1993233, -6.670935 , ..., -7.3002615, -3.823823 ,
         1.8125833],
       [-1.1050931, -2.4256568, -7.3015127, ..., -6.1500154, -4.074909 ,
         1.8155754],
       [ 6.172593 , -2.2252593, -9.146653 , ..., -7.70834  , -4.810748 ,
         0.3985293],
       [ 1.4988875,  1.0973434, -4.4714937, ..., -4.8026247, -1.1791464,
         1.6924176]], dtype=float32) finished=True finished_reason='length'
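
The zero padding visible in the first score array can be illustrated with a short NumPy sketch. This is not the pipeline's implementation: `pad_logits` and `strip_padding` are hypothetical helpers, and `strip_padding` assumes that a trailing row consisting entirely of zeros is padding rather than a genuine logits row.

```python
import numpy as np


def pad_logits(logits, target_len):
    """Append all-zero rows so that logits has target_len rows."""
    pad_rows = target_len - logits.shape[0]
    if pad_rows < 0:
        raise ValueError("logits are longer than target_len")
    # Pad only at the end of the sequence axis; leave the vocab axis alone.
    return np.pad(logits, ((0, pad_rows), (0, 0)), constant_values=0.0)


def strip_padding(scores):
    """Drop trailing rows that are entirely zero (ASSUMED to be padding)."""
    is_real = np.any(scores != 0.0, axis=-1)
    if not is_real.any():
        return scores[:0]
    last_real = int(is_real.nonzero()[0].max())
    return scores[: last_real + 1]


short = np.ones((3, 5), dtype=np.float32)   # stand-in for a shorter prompt
padded = pad_logits(short, target_len=4)    # same shape as the longest prompt
print(padded.shape, strip_padding(padded).shape)
```

Stripping the padded rows before computing per-token statistics keeps the zero rows from being treated as real predictions.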

Next steps