
Painless Inference Acceleration (PIA)

A toolkit for LLM inference without 😭. It currently contains our work LOOKAHEAD, a framework that accelerates LLM inference without loss of accuracy; other works will be released soon.

News or Update 🔥

Models we support

Known issues & TODO

ISSUE 1. repetition_penalty is not fully supported; we will fix it in the future.

ISSUE 2. Lookahead may generate responses that differ from the original ones due to low-precision data types (i.e., fp16 or bf16); the responses are identical with fp32.

ISSUE 3. The Baichuan tokenizer cannot be initialized with the latest version of transformers (4.30.2 works).

ISSUE 4. The Qwen model may generate slightly different responses with lookahead when the repetition_penalty parameter is set.

TODO 1: Support the latest version of 🤗 transformers. Currently it's based on 4.30.2.

TODO 2: Integrate our work FastCoT.

TODO 3: Optimize the batch inference implementation with flash-attention.

Performance Comparison

Performance is measured in tokens/s (tokens per second) of generated tokens.

Public datasets and models

We use the test set for evaluation and the train set for trie-tree cache construction. The hyper-parameters are tuned by grid search. The tag fused indicates that operators are fused with Triton; the implementation can be found in modeling_llama_batch.py.

| model | dataset | GPU | 🤗 transformers (tokens/s) | lookahead (tokens/s) |
|---|---|---|---|---|
| Llama2-7b-chat | Dolly-15k | A100-80G | 40.6 | 83.7 (x2.06) |
| Llama2-7b-chat | GSM-8k | A100-80G | 41.4 | 111.3 (x2.69) |
| Llama2-7b-chat (fused) | Dolly-15k | A100-80G | 50.4 | 106.8 (x2.12) |
| Llama2-7b-chat (fused) | Dolly-15k | A10 | 31.4 | 55.7 (x1.77) |
| Llama2-7b-chat (fused) | GSM-8k | A100-80G | 53.7 | 149.6 (x2.79) |
| Llama2-7b-chat (fused) | GSM-8k | A10 | 31.4 | 68.1 (x2.17) |
| Llama2-7b-chat (fused) | Humaneval-x | A100-80G | 51.1 | 161.5 (x3.16) |
| Llama2-7b-chat (fused) | Humaneval-x | A10 | 30.9 | 89.6 (x2.90) |
| Llama2-13b-chat | Dolly-15k | A100-80G | 34.0 | 71.7 (x2.11) |
| Llama2-13b-chat | GSM-8k | A100-80G | 31.2 | 71.1 (x2.28) |
| Llama2-13b-chat (fused) | Dolly-15k | A100-80G | 39.9 | 84.6 (x2.12) |
| Llama2-13b-chat (fused) | Dolly-15k | V100-32G | 20.5 | 35.2 (x1.72) |
| Llama2-13b-chat (fused) | GSM-8k | A100-80G | 42.9 | 103.4 (x2.41) |
| Llama2-13b-chat (fused) | GSM-8k | V100-32G | 22.0 | 45.6 (x2.07) |
| Llama2-13b-chat (fused) | Humaneval-x | A100-80G | 35.0 | 137.3 (x3.92) |
| Llama2-13b-chat (fused) | Humaneval-x | V100-32G | 21.5 | 57.0 (x2.65) |
| ChatGLM2-6b | Dolly-15k | A100-80G | 45.6 | 108.4 (x2.38) |
| ChatGLM2-6b | GSM-8k | A100-80G | 43.3 | 94.0 (x2.17) |

We test 5 examples with Llama2-7b-chat on the Dolly dataset: inference time without lookahead (left figure) is 15.7s (48.2 tokens/s), while inference time with lookahead is 6.4s (112.9 tokens/s), a speedup of 2.34x.
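
For reference, the reported speedup is simply the ratio of the two generation throughputs, e.g., for the Dolly example above:

```python
# Speedup is the ratio of generation throughputs (tokens per second).
baseline_tps = 48.2    # 🤗 transformers: 5 Dolly examples in 15.7s
lookahead_tps = 112.9  # lookahead:       the same 5 examples in 6.4s
print(f'speedup: x{lookahead_tps / baseline_tps:.2f}')  # -> speedup: x2.34
```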

Private datasets and models

We use the first 1000 samples for evaluation and the rest for trie-tree cache construction. The hyper-parameters are decoding_length=128 and branch_length=32.

Our method achieves significant acceleration in RAG (Retrieval-Augmented Generation) scenarios. However, no suitable real-life datasets are publicly available at the moment, so we only evaluate on our private datasets and models. AntGLM-10B is an LLM developed by Ant Group with the GLM architecture.

| model | scenario | GPU | 🤗 transformers (tokens/s) | lookahead (tokens/s) |
|---|---|---|---|---|
| AntGLM-10b | Citizen Biz Agent | A100-80G | 52.4 | 280.9 (x5.36) |
| AntGLM-10b | Citizen Biz Agent | A10 | 20.3 | 105.1 (x5.18) |
| AntGLM-10b | Citizen Biz Agent | V100-32G | 27.3 | 118.9 (x4.36) |
| AntGLM-10b | Enterprise Info QA | A100-80G | 50.7 | 259.1 (x5.11) |
| AntGLM-10b | Health Suggestion | A100-80G | 51.6 | 240.2 (x4.66) |

We test 5 examples with AntGLM-10B on the AntRag dataset: inference time without lookahead (left figure) is 16.9s (33.8 tokens/s), while inference time with lookahead is 3.9s (147.6 tokens/s), a speedup of 4.37x.

Introduction

Our repo PIA (short for Painless Inference Acceleration) is used for LLM inference; it is based on the 🤗 transformers library.

Note that our work is different from another method named lookahead decoding.

Lookahead workflow

(figure: lookahead workflow)

Lookahead mask

(figure: lookahead mask)

Trie construction

(figure: trie construction)

Trie retrieval

(figure: trie retrieval)
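
The figures above sketch how a trie cache is built from previously seen token sequences and how a branch of candidate continuations is retrieved for a given prefix. As a rough illustration of the idea only (not the repository's LookaheadCache implementation), a token trie could look like:

```python
class TokenTrie:
    """Toy token trie for illustration; the real LookaheadCache differs."""

    def __init__(self):
        self.children = {}  # token id -> child TokenTrie
        self.freq = 0       # how many inserted sequences passed through this node

    def insert(self, token_ids):
        """Insert a token sequence (e.g., from a prompt or a generated response)."""
        node = self
        for t in token_ids:
            node = node.children.setdefault(t, TokenTrie())
            node.freq += 1

    def retrieve(self, prefix, max_length=8):
        """Return one frequent continuation of `prefix`, up to `max_length` tokens."""
        node = self
        for t in prefix:
            if t not in node.children:
                return []
            node = node.children[t]
        draft = []
        while node.children and len(draft) < max_length:
            # follow the most frequent child at each step
            t, node = max(node.children.items(), key=lambda kv: kv[1].freq)
            draft.append(t)
        return draft


trie = TokenTrie()
trie.insert([1, 2, 3, 4, 5])
trie.insert([1, 2, 3, 6])
print(trie.retrieve([1, 2]))  # [3, 4, 5]
```

Retrieved branches serve as draft tokens that the model verifies in its forward pass (using the lookahead mask), which is how generation stays lossless.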

License

CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/)

Installation

  1. Clone this repository and navigate to PainlessInferenceAcceleration
    git clone https://github.com/alipay/PainlessInferenceAcceleration.git
    cd PainlessInferenceAcceleration
  2. Install the package
    python setup.py install

Quick Start

Below is a minimal example of using lookahead for inference:


```python
import torch
from transformers import AutoTokenizer

from pia.lookahead.common.lookahead_cache import LookaheadCache
from pia.lookahead.models.llama.modeling_llama import LlamaForCausalLM

model_dir = 'meta-llama/Llama-2-7b-chat-hf'
model = LlamaForCausalLM.from_pretrained(model_dir,
                                         cache_dir='./',
                                         torch_dtype=torch.float16,
                                         low_cpu_mem_usage=True,
                                         device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_dir)

prompt = "Hello, I'm am conscious and"
inputs = tokenizer(prompt, return_tensors="pt")

output_ids = model.generate(input_ids=inputs.input_ids.cuda(),
                            attention_mask=inputs.attention_mask.cuda(),
                            max_new_tokens=256,
                            decoding_kwargs={'use_lookahead': True})
response = tokenizer.decode(output_ids[0].tolist())
print(f'{response=}')
```
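
The benchmark sections above mention lookahead hyper-parameters such as decoding_length and branch_length. As a hedged sketch (assuming these can also be passed through decoding_kwargs; the exact keys accepted may differ), reusing model and inputs from the example above:

```python
# Sketch only: decoding_length and branch_length inside decoding_kwargs are an
# assumption here; check the scripts in examples/ and benchmarks/ for the exact usage.
output_ids = model.generate(input_ids=inputs.input_ids.cuda(),
                            attention_mask=inputs.attention_mask.cuda(),
                            max_new_tokens=256,
                            decoding_kwargs={'use_lookahead': True,
                                             'decoding_length': 128,
                                             'branch_length': 32})
```

In the experiments above, these values were tuned by grid search for the public datasets and fixed to decoding_length=128 and branch_length=32 for the private datasets.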

To use lookahead with other models, we can run the scripts in the path examples/. Each supported model is included and can be used for correctness evaluation.

python [model name]_example.py

Benchmarks

To evaluate the speedup of lookahead, we can run the scripts in the path benchmarks/; the preprocessing of the datasets can be found in benchmarks/preprocess_sample.py.

To inspect the running details of lookahead, we can turn on return_dict_in_generate, i.e.,

```python
outputs = model.generate(...,
                         return_dict_in_generate=True)
output_ids = outputs.sequences
kwargs = outputs.kwargs
# edls: short for effective decoding lengths, i.e., the number of tokens generated in one forward
# pass; always >= 1 (even without lookahead, one token is generated per forward, so edls=1).
edls = kwargs['edls']
# dls: short for decoding lengths, i.e., the number of tokens fed into one forward pass; always >= 1.
# Note that it is set to 1 instead of the prompt length in the prefill stage.
dls = kwargs['dls']
# fts: short for forward times (s); the first is the prefill time and the others are decoding times.
fts = kwargs['fts']
# qts: short for query times (s), i.e., the time for retrieving a sub trie tree.
qts = kwargs['qts']
```
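
These lists can be reduced to simple summary statistics, for example:

```python
# Rough summary of the per-forward statistics returned above.
mean_edl = sum(edls) / len(edls)   # average number of tokens generated per forward pass
prefill_time = fts[0]              # the first forward is the prefill
decode_time = sum(fts[1:])         # the remaining forwards are decoding steps
print(f'mean effective decoding length: {mean_edl:.2f}')
print(f'approximate decoding throughput: {sum(edls) / decode_time:.1f} tokens/s')
print(f'total trie query time: {sum(qts):.3f}s')
```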

Customize Model

To support a customized model, usually we only need to add a few lines. Here is an example for supporting Llama:

```python
from pia.lookahead.common.pretrained_model import LookaheadPreTrainedModel


class LlamaPreTrainedModel(LookaheadPreTrainedModel):
    '''
    other code
    '''


class LlamaModel(LlamaPreTrainedModel):
    '''
    other code
    '''

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        '''
        other code
        '''

        """
        NOTE: adapt for lookahead.
        Lookahead always uses a rank-4 tensor for attention_mask, so a minimal adaptation
        for lookahead is routed by the rank:
        with lookahead, generate position_ids from attention_mask and set zero elements of the mask to -inf.
        """
        if attention_mask is not None and len(attention_mask.shape) == 4:
            # with lookahead
            position_ids = torch.sum(attention_mask, dim=-1).squeeze(1) - 1
            attention_mask = (1.0 - attention_mask.to(inputs_embeds.dtype)) * torch.finfo(inputs_embeds.dtype).min
        else:
            # without lookahead, reuse the original code lines
            if position_ids is None:
                device = input_ids.device if input_ids is not None else inputs_embeds.device
                position_ids = torch.arange(
                    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
                )
                position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
            else:
                position_ids = position_ids.view(-1, seq_length).long()

            if attention_mask is None:
                attention_mask = torch.ones(
                    (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
                )
            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )
```

Note that the above adaptation cannot be used for batch inference, as the number of generated tokens may vary across samples. The adaptation for batch inference can be found in `models/modeling_glm_batch.py` or `models/modeling_llama_batch.py`. `Flash-attention`-enhanced batch inference is under development.
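
To see concretely what the lookahead branch above does, here is a small standalone illustration (toy tensors only, not tied to any model) of deriving position_ids and the additive attention bias from a rank-4 0/1 attention mask:

```python
import torch

# A rank-4 mask of shape (batch, 1, query_len, key_len); 1 = attend, 0 = masked.
mask = torch.tensor([[[[1, 1, 1, 0, 0],
                       [1, 1, 1, 1, 0],
                       [1, 1, 1, 1, 1]]]], dtype=torch.float16)

# Each query token's position is (number of visible keys - 1).
position_ids = torch.sum(mask, dim=-1).squeeze(1) - 1
print(position_ids)        # -> [[2., 3., 4.]]

# Zero entries become a large negative bias added to the attention scores.
additive_mask = (1.0 - mask) * torch.finfo(torch.float16).min
print(additive_mask[0, 0])
```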

Tests

Tests can be run with:

cd pia/lookahead
pytest tests/ -s

Citations

```
@misc{zhao2023lookahead,
  title={Lookahead: An Inference Acceleration Framework for Large Language Model with Lossless Generation Accuracy},
  author={Yao Zhao and Zhitian Xie and Chenyi Zhuang and Jinjie Gu},
  year={2023},
  eprint={2312.12728},
  archivePrefix={arXiv},
  primaryClass={cs.IR}
}
```