hetailang / SqueezeAttention


How to reproduce the throughput improvement effect represented in the paper? #2

Open caojinpei opened 3 months ago

caojinpei commented 3 months ago

In your paper, you report that the maximum throughput improves by about 2.2x. How can this be reproduced with the code in your GitHub repository? Could you give me some detailed instructions?

hetailang commented 3 months ago

Thank you for your interest in our project!

import time

model, tokenizer = load()
# remember to replace the model with SqueezeAttention
prompt = get_prompt()
# get some data from any dataset
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
input_ids = input_ids.repeat(batch_size, 1)
# a naive way to increase batch_size; you can also make the input in every batch different (see the sketch below)
start = time.time()
output = model.generate(input_ids, max_new_tokens=out_length, min_new_tokens=out_length, use_cache=True)
# we use max_new_tokens and min_new_tokens to control the output length
end = time.time()
print('Time: {}s'.format(end - start))
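
If you want every element of the batch to use a different prompt, the inputs must be padded to a common length and an explicit attention mask passed to generate(). A minimal sketch using the standard transformers API (here prompts is assumed to be a list of batch_size strings; it is not part of the snippet above):

# assumption: prompts is a list of batch_size different strings
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token
enc = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
output = model.generate(
    input_ids=enc.input_ids,
    attention_mask=enc.attention_mask,
    max_new_tokens=out_length,
    min_new_tokens=out_length,
    use_cache=True,
)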

This is the main idea of how we test the throughput improvement of SqueezeAttention; you will need to implement it on top of your own system. Because accuracy is already demonstrated in the other experiments, for the throughput experiments we only care about the input and output lengths, the maximum batch size that fits in limited memory, and the time the run takes. We use output length * batch_size / time as the throughput metric; for the specific input/output lengths, please refer to our paper.
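
Concretely, with the variables from the snippet above, the metric reduces to the following (a minimal sketch of the calculation, nothing more):

elapsed = end - start                           # wall-clock seconds spent in generate()
throughput = out_length * batch_size / elapsed  # generated tokens per second
print('Throughput: {:.1f} tokens/s'.format(throughput))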

caojinpei commented 3 months ago

Hi, I have done as you said, but I found that when I set batch_size > 2, I get the error below. Can you help me check it?

Traceback (most recent call last):
  File "/home/sr5/yunfeng.gong/Work/LLM_Inference/SqueezeAttention/pred.py", line 355, in <module>
    preds = helm(model, tokenizer, data, args)
  File "/home/sr5/yunfeng.gong/Work/LLM_Inference/SqueezeAttention/pred.py", line 240, in helm
    output_sequences = model.generate(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/sr5/yunfeng.gong/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 1758, in generate
    result = self._sample(
  File "/home/sr5/yunfeng.gong/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 2397, in _sample
    outputs = self(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sr5/yunfeng.gong/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1035, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sr5/yunfeng.gong/Work/LLM_Inference/SqueezeAttention/utils_hh/modify_llama_drop.py", line 515, in forward
    layer_outputs, hidd_data = decoder_layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sr5/yunfeng.gong/Work/LLM_Inference/SqueezeAttention/utils_hh/modify_llama_drop.py", line 307, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sr5/yunfeng.gong/Work/LLM_Inference/SqueezeAttention/utils_hh/modify_llama_drop.py", line 220, in forward
    raise ValueError(
ValueError: Attention mask should be of size (3, 1, 1, 2025), but is torch.Size([3, 1, 1, 2024])

hetailang commented 2 months ago

Sorry for the late reply. In the Installation section of the README, we ask you to run modify_transformers.py to replace some files inside transformers. Are you sure you ran that script?
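
For reference, that step is a single command run once from the repository root (the exact invocation may differ; please check the README):

python modify_transformers.py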