caojinpei opened 3 months ago
Thank you for your interest in our project!
import time

model, tokenizer = load()
# remember to replace the model with SqueezeAttention
prompt = get_prompt()
# get some data from any dataset
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# tokenizer(...) returns a BatchEncoding; take the input_ids tensor
input_ids = input_ids.repeat(batch_size, 1).view(batch_size, -1)
# a naive way to increase batch_size; you can also make the input in every batch different
start = time.time()
output = model.generate(input_ids, max_new_tokens=out_length, min_new_tokens=out_length, use_cache=True)
# we use max_new_tokens and min_new_tokens together to fix the output length
end = time.time()
print('Time: {}s'.format(end - start))
This is the main idea of how we test the throughput improvement of SqueezeAttention; you need to implement it based on your own system. Because we have already demonstrated accuracy in the other experiments, for throughput we only care about the input and output lengths, the maximum batch size under limited memory, and the time it takes. We use output length * batch_size / time to represent throughput; for the specific values of the input/output lengths, please refer to our paper.
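The throughput metric above can be written out as a small helper (a sketch only; `out_length` and `batch_size` are the same quantities as in the snippet, and no repo code is assumed):

```python
def throughput(out_length, batch_size, seconds):
    """Tokens generated per second across the whole batch:
    output length * batch size / wall-clock time."""
    return out_length * batch_size / seconds

# e.g. 128 new tokens per sequence, batch of 8, 4.0 s of generation
print(throughput(128, 8, 4.0))  # -> 256.0 tokens/s
```

Comparing this number for the baseline model and the SqueezeAttention model at their respective maximum batch sizes gives the reported improvement ratio.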
Hi, I have done as you said, but I found that when I set batch_size > 2 I get the error below. Can you help me check it?
Traceback (most recent call last):
File "/home/sr5/yunfeng.gong/Work/LLM_Inference/SqueezeAttention/pred.py", line 355, in
Sorry for the late reply.
In the Installation section of the README, you need to run modify_transformers.py to replace some files in transformers. Are you sure you ran this script?
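As a hedged sanity check for that step (the helper below is illustrative and not part of the SqueezeAttention repo): modify_transformers.py has to overwrite files inside the installed transformers package, so it helps to confirm which installation would actually be patched:

```python
import importlib.util
import os

def package_dir(name):
    """Return the directory of an installed package, or None if absent.
    Illustrative helper, not part of the SqueezeAttention repo."""
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:
        return None
    return os.path.dirname(spec.origin)

# This is the directory whose files modify_transformers.py must replace;
# if it prints None, transformers is not installed in this environment.
print(package_dir("transformers"))
```

If this prints a path inside a different virtual environment than the one you benchmark in, the patched files are not being picked up.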
In your paper, you report that the maximum throughput improves by about 2.2x. How can I reproduce that using your GitHub code? Can you give me some detailed instructions?