ludwig-ai / ludwig

Low-code framework for building custom LLMs, neural networks, and other AI models
http://ludwig.ai
Apache License 2.0

Add support for prompt lookup decoding during generation #3917

Closed: arnavgarg1 closed this 9 months ago

arnavgarg1 commented 9 months ago

Implements support for Prompt Lookup Decoding by exposing a new generation config parameter, prompt_lookup_num_tokens. Requires transformers >= 4.37.0.
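As a sketch of where the new parameter might live, the snippet below adds prompt_lookup_num_tokens to the generation section of a Ludwig-style config dict that mirrors the demo's YAML, instead of passing it on every call. Two things here are assumptions: that the parameter is accepted in the config's generation section, and the helpers version_tuple and with_prompt_lookup, which are hypothetical. The version gate only encodes the transformers >= 4.37.0 requirement stated above.

```python
import copy
import re
from typing import Any, Dict, Tuple

# Minimum transformers version stated in the PR description.
MIN_TRANSFORMERS: Tuple[int, int, int] = (4, 37, 0)

# Mirrors the `generation` section of the demo config; values are illustrative.
BASE_CONFIG: Dict[str, Any] = {
    "model_type": "llm",
    "generation": {"max_new_tokens": 64, "temperature": 0.1},
}


def version_tuple(version: str) -> Tuple[int, int, int]:
    """Hypothetical helper: "4.37.0.dev0" -> (4, 37, 0). Pre-release tags are ignored."""
    nums = [int(n) for n in re.findall(r"\d+", version)[:3]]
    return tuple(nums + [0] * (3 - len(nums)))  # type: ignore[return-value]


def with_prompt_lookup(
    config: Dict[str, Any], num_tokens: int, transformers_version: str
) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of `config` with prompt lookup decoding enabled."""
    if version_tuple(transformers_version) < MIN_TRANSFORMERS:
        raise ValueError(
            f"prompt_lookup_num_tokens needs transformers >= 4.37.0, found {transformers_version}"
        )
    new_config = copy.deepcopy(config)  # leave the caller's config untouched
    new_config.setdefault("generation", {})["prompt_lookup_num_tokens"] = num_tokens
    return new_config


if __name__ == "__main__":
    # In practice the version string could come from importlib.metadata.version("transformers").
    print(with_prompt_lookup(BASE_CONFIG, 10, "4.37.0")["generation"])
```

Passing the parameter per call, as the demo script does, leaves the stored config unchanged.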

When the prompt is long and the generated output is likely to reuse many of its n-grams, this can speed up token generation by roughly 2x to 2.4x. When that is not the case, such as with open-ended questions, it leads to a 10% decrease in tokens per second.
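To make this trade-off concrete, here is a minimal pure-Python sketch of the idea, assuming greedy decoding. The draft step copies the tokens that followed an earlier occurrence of the sequence's trailing n-gram; the verify step keeps only the draft tokens the model would have produced anyway, plus one token of its own. This is not the transformers implementation: the toy next_token oracle, the n-gram size of 2, and all function names are assumptions. In a real model every draft position is scored in a single forward pass, so the step count below stands in for the number of forward passes.

```python
from typing import Callable, List, Sequence, Tuple


def find_candidate_tokens(
    ids: Sequence[int], num_output_tokens: int = 10, max_ngram_size: int = 2
) -> List[int]:
    """Draft step: copy the tokens that followed an earlier occurrence of the trailing n-gram."""
    length = len(ids)
    for ngram_size in range(max_ngram_size, 0, -1):  # prefer longer matches
        ngram = list(ids[-ngram_size:])
        # Earliest match wins; the trailing n-gram itself (start = length - ngram_size) is excluded.
        for start in range(length - ngram_size):
            if list(ids[start:start + ngram_size]) == ngram:
                cont_start = start + ngram_size
                return list(ids[cont_start:cont_start + num_output_tokens])
    return []  # no match: fall back to ordinary one-token-per-step decoding


def generate_with_prompt_lookup(
    prompt_ids: Sequence[int],
    next_token: Callable[[List[int]], int],
    max_new_tokens: int,
    num_output_tokens: int = 10,
) -> Tuple[List[int], int]:
    """Greedy decoding with prompt-lookup drafts. Returns (new tokens, verification steps)."""
    ids = list(prompt_ids)
    steps = 0
    while len(ids) - len(prompt_ids) < max_new_tokens:
        draft = find_candidate_tokens(ids, num_output_tokens)
        steps += 1  # a real model scores all draft positions in one forward pass
        accepted: List[int] = []
        for token in draft:
            if token != next_token(ids + accepted):
                break  # first disagreement: discard this and all later draft tokens
            accepted.append(token)
        # The same pass also yields the model's own token right after the accepted prefix.
        accepted.append(next_token(ids + accepted))
        remaining = max_new_tokens - (len(ids) - len(prompt_ids))
        ids.extend(accepted[:remaining])
    return ids[len(prompt_ids):], steps


def make_toy_model(prompt_len: int, target: List[int]) -> Callable[[List[int]], int]:
    """Hypothetical stand-in for a greedy LM: emits `target` in order, then token 0."""
    def next_token(ids: List[int]) -> int:
        i = len(ids) - prompt_len
        return target[i] if i < len(target) else 0
    return next_token


if __name__ == "__main__":
    prompt = list(range(10, 20))  # stand-in for the tokenized prompt
    # Output that re-uses a span of the prompt, as when editing code given in the prompt.
    echo = [12, 13, 14, 15, 16, 17, 18, 19, 99, 98]
    # Output that shares no tokens with the prompt, as with an open-ended answer.
    novel = list(range(50, 60))
    for name, target in (("echo", echo), ("novel", novel)):
        out, steps = generate_with_prompt_lookup(prompt, make_toy_model(len(prompt), target), 10)
        assert out == target  # accepted drafts never change the greedy output
        print(f"{name}: {len(out)} tokens in {steps} steps")
```

Run as a script, the echo case produces its 10 tokens in 3 steps, while the novel case finds no drafts and needs 10 steps, one per token. With a real model, drafts that are found by chance and then rejected add work without saving any steps, which is one plausible source of the slowdown on open-ended questions.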

Demo

https://drive.google.com/file/d/1E8qq8HnJBhL7GOuFDuMdih_GY1aAVwEC/view?usp=sharing

Script to Reproduce Demo

````python
import yaml
import logging
from ludwig.api import LudwigModel

config = yaml.safe_load(
    """
model_type: llm
base_model: meta-llama/Llama-2-7b-chat-hf

quantization:
  bits: 4

input_features:
  - name: instruction
    type: text

output_features:
  - name: output
    type: text

generation:
  max_new_tokens: 64
  temperature: 0.1

trainer:
  type: none

backend:
  type: local
"""
)

model = LudwigModel(config, logging_level=logging.INFO)

code_text = """import numpy as np
import matplotlib.pyplot as plt

# Calculate the average
average_throughput = np.mean(tokens_per_sec_arr)
print(f"Average Throughput: {average_throughput} tokens/sec")

# Plotting the histogram
plt.hist(tokens_per_sec_arr, bins=20, color='blue', edgecolor='black', alpha=0.7)
plt.title('Histogram of Throughput Values')
plt.xlabel('Tokens per Second')
plt.ylabel('Frequency')
plt.axvline(average_throughput, color='red', linestyle='dashed', linewidth=1)
plt.text(average_throughput*0.9, max(plt.ylim())*0.9, f'Average: {average_throughput:.2f}', color = 'red')
plt.show()
"""

question = "Can you please change x axis to start from 0"

prompt = """<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
```python\n{code_text}``` \n\n{question}<|im_end|>
<|im_start|>assistant
""".format(code_text=code_text, question=question)

# Normal generation: ~17 s
output = model.generate(
    prompt,
    generation_config={"max_new_tokens": 500},
    streaming=True,
)

# With Prompt Lookup Decoding: ~8.7 s
output = model.generate(
    prompt,
    generation_config={"max_new_tokens": 500, "prompt_lookup_num_tokens": 10},
    streaming=True,
)
````
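To check the speedup on other prompts, a small timing helper can wrap the two generate calls from the script. median_latency and speedup are hypothetical helpers, not Ludwig APIs. The commented usage assumes the model and prompt objects defined in the script and leaves streaming off so console output does not affect the timing. Wall-clock time only compares fairly when both calls produce outputs of similar length; dividing the number of generated tokens by the time gives the tokens-per-second figure used in the description. With the timings quoted in the script's comments (~17 s vs ~8.7 s), the ratio works out to about 1.95x.

```python
import statistics
import time
from typing import Any, Callable


def median_latency(fn: Callable[[], Any], runs: int = 3, warmup: int = 1) -> float:
    """Median wall-clock seconds over `runs` calls of fn(), after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()  # untimed: absorbs one-off costs such as model or kernel warm-up
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def speedup(baseline_s: float, candidate_s: float) -> float:
    """How many times faster the candidate run is than the baseline run."""
    return baseline_s / candidate_s


# Hypothetical usage with `model` and `prompt` from the script above:
# base = median_latency(lambda: model.generate(prompt, generation_config={"max_new_tokens": 500}))
# pld = median_latency(lambda: model.generate(
#     prompt, generation_config={"max_new_tokens": 500, "prompt_lookup_num_tokens": 10}))
# print(f"{speedup(base, pld):.2f}x")

if __name__ == "__main__":
    print(f"{speedup(17.0, 8.7):.2f}x")  # timings quoted in the script's comments
```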
github-actions[bot] commented 9 months ago

Unit Test Results

6 files ±0, 6 suites ±0, duration 14m 15s (-1s)
12 tests ±0: 9 passed ±0, 3 skipped ±0, 0 failed ±0
60 runs ±0: 42 passed ±0, 18 skipped ±0, 0 failed ±0

Results for commit 218f58b6. ± values are relative to base commit 9bb89c6c.
