The operators, pipeline, and routes added for text_generation so far make up the flow shown above.
Testing
The test script below evaluates the following situations (the sketch after this list illustrates how the prompt is split between engines in each case):

- Only the single-token engine runs, when prompt_sequence_length > the number of prompt tokens.
- Only the multi-token engine runs, when the number of prompt tokens % prompt_sequence_length == 0.
- Both the multi-token and single-token engines run, when prompt_sequence_length <= the number of prompt tokens but the number of prompt tokens % prompt_sequence_length != 0.
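To make the engine dispatch concrete, here is a minimal sketch of how many tokens each engine would process for a given prompt length. The helper below is purely illustrative and not part of the pipeline; it just applies the divisibility rule described in the cases above.

```python
def split_prompt(num_prompt_tokens: int, prompt_sequence_length: int):
    """Illustrative only: how prompt tokens would be divided between the engines."""
    # Full chunks of prompt_sequence_length go through the multi-token engine.
    multi_token_chunks = num_prompt_tokens // prompt_sequence_length
    # Any remaining tokens are processed one at a time by the single-token engine.
    single_token_steps = num_prompt_tokens % prompt_sequence_length
    return multi_token_chunks, single_token_steps


print(split_prompt(num_prompt_tokens=7, prompt_sequence_length=16))  # (0, 7) -> single-token only
print(split_prompt(num_prompt_tokens=8, prompt_sequence_length=4))   # (2, 0) -> multi-token only
print(split_prompt(num_prompt_tokens=7, prompt_sequence_length=4))   # (1, 3) -> both engines
```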
All three cases are tested with the script below by changing prompt_sequence_length. The prompt logits are compared against ground-truth logits produced by the corresponding transformers model.
```python
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from deepsparse.transformers.pipelines.text_generation import TextGenerationInput
from deepsparse.v2.text_generation.pipeline import TextGenerationPipeline


def create_tokenizer(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def get_ground_truth(prompt):
    # Run the same prompt through the dense transformers model to get reference logits.
    model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
    tokenizer = create_tokenizer("roneneldan/TinyStories-1M")
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    out = model(input_ids=input_ids)
    return out.logits.detach().numpy()


cur_len = 16  # prompt_sequence_length; vary this to exercise each of the three cases
prompt = "Hello there, how are you?"
model_path = "hf:mgoin/TinyStories-1M-deepsparse"

pipeline = TextGenerationPipeline(
    model_path,
    engine_kwargs={"engine_type": "onnxruntime"},
    prompt_sequence_length=cur_len,
)
input_values = TextGenerationInput(prompt=prompt)
logits = pipeline(input_values)

ground_truth = get_ground_truth(prompt)
print("All Close?", np.allclose(logits, ground_truth, atol=0.0001))
```