syncdoth / RetNet

Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
MIT License

Question about verifying the Inference Latency #8

Closed LiZeng001 closed 1 year ago

LiZeng001 commented 1 year ago

Hi, Thank you for your great work!

When I use your example code to compare inference latency against a Transformer-based LLM, the result does not show the speedup reported in the paper (15.6X). Could you please give some help?

I ran a generation task of the same length on RetNet-6.7B and a similarly sized model (chinese_llama_7b); the code and results are below. (1) RetNet-6.7B

```python
import torch
import time
from retnet.modeling_retnet import RetNetModelWithLMHead
from retnet.configuration_retnet import load_config_from_yaml
from transformers import AutoTokenizer

config = load_config_from_yaml('configs/retnet-6.7b.yml')
model = RetNetModelWithLMHead(config).to("cuda")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.model_max_length = 2048
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer("Retention refers to", return_tensors='pt').to("cuda")
start_time = time.time()
generated = model.generate(**inputs, parallel_compute_prompt=True, max_new_tokens=100,
                           eos_token_id=tokenizer.encode("\n")[0])
end_time = time.time()
print("time cost:", int(float(end_time - start_time) * 1000.0), "ms")
print(generated.shape)
print(tokenizer.batch_decode(generated))
```

RetNet-6.7B Latency:

```
time cost: 4929 ms
torch.Size([1, 100])
['Japan Muslimcue Developers absence finals rundownutenberg agre Reefaqurr certifyprice swap Lupidity431 acutelyprice Monthly Loading heftytri tengpnatdemocratic Rating Dayamping unc Future Townsenv ArnoldJP Crusader despite axapple Warriors else Borderlands collaborativeios mounted lingerusp Greene Celtics PHOTO McCainagusReallyParam < limitations detect fixesundrum isot towering there proclaimed503 CM white 1979 schizophren teamongo PearsonMurray shadow NEXT activated GunogenouspriceIndust pedestrian scattering locnaiamelessintelligence resorted SenegalHall Keystoneroute wrap850Bernie divert Sadd witch Completed Sly']
```

========================================================================

(2) chinese_llama_7b

```python
import torch
import time
from transformers import GPTNeoXForCausalLM, AutoTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM

model_path = "/pretrain-models/llama/chinese_llama_lora_7b_merge"
gpu_available = torch.cuda.is_available()
print("gpu: ", gpu_available)
model = AutoModelForCausalLM.from_pretrained(model_path).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.model_max_length = 2048

inputs = tokenizer("Retention refers to", return_tensors="pt").to("cuda")
start_time = time.time()
tokens = model.generate(
    **inputs,
    temperature=0.9,
    max_new_tokens=100,
    eos_token_id=tokenizer.encode("\n")[0],
)
end_time = time.time()
print("time cost:", int(float(end_time - start_time) * 1000.0), "ms")
print(tokens.shape)
print(tokenizer.batch_decode(tokens))
```

chinese_llama_7b Latency:

```
time cost: 2280 ms
torch.Size([1, 104])
['Retention refers to mortar invade diarrFri Lak Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated"']
```

=======================================================================

As a result, to generate text of the same length from the same input, RetNet-6.7B took 4929 ms while chinese_llama_7b took 2280 ms, which is not consistent with the paper. Is my test method wrong, or is something else going on? Any help would be greatly appreciated!
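(For reference, single-call GPU timings can be skewed by one-off CUDA initialization and by asynchronous kernel launches, so a fairer comparison usually adds a warmup pass and explicit synchronization. The sketch below is illustrative only; `timed_generate` is a hypothetical helper, not part of either repository.)

```python
import time
import torch

def timed_generate(model, tokenizer, prompt, max_new_tokens=100, warmup=1):
    """Time model.generate() with a GPU warmup pass and explicit synchronization."""
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        # Warmup: the first CUDA pass pays for lazy initialization and allocator growth.
        for _ in range(warmup):
            model.generate(**inputs, max_new_tokens=max_new_tokens)
        torch.cuda.synchronize()
        start = time.time()
        out = model.generate(**inputs, max_new_tokens=max_new_tokens)
        torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    return out, (time.time() - start) * 1000.0
```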

jploski commented 1 year ago

What you are testing with the `generate` call above is the "parallel" forward pass. I believe that to obtain the O(1) inference performance you need to generate tokens using the recurrent method (or the chunkwise method, which is a trade-off between recurrent and parallel).
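To make the asymptotics concrete, here is a minimal single-head sketch of the two retention forms from the paper (omitting xpos rotation, group norm, gating, and multi-scale decay): the parallel form does O(n^2) work over the whole sequence, like attention, while the recurrent form carries a fixed-size state and does O(d^2) work per generated token, independent of context length.

```python
import torch

def parallel_retention(q, k, v, gamma):
    """Parallel form: Retention(X) = (Q K^T ⊙ D) V, with D[n, m] = gamma^(n-m) for n >= m.
    Cost grows quadratically with sequence length."""
    n = q.shape[0]
    idx = torch.arange(n)
    decay = gamma ** (idx[:, None] - idx[None, :]).clamp(min=0).float()
    mask = (idx[:, None] >= idx[None, :]).float()
    return ((q @ k.T) * decay * mask) @ v

def recurrent_retention_step(q_n, k_n, v_n, state, gamma):
    """Recurrent form: S_n = gamma * S_{n-1} + k_n^T v_n, output = q_n S_n.
    The state is a fixed (d, d) matrix, so per-token cost does not depend on history length."""
    state = gamma * state + k_n[:, None] @ v_n[None, :]
    return q_n @ state, state

# Both forms give the same outputs (up to numerical error):
d, n, gamma = 8, 16, 0.9
q, k, v = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
par = parallel_retention(q, k, v, gamma)
state, rec = torch.zeros(d, d), []
for t in range(n):
    o, state = recurrent_retention_step(q[t], k[t], v[t], state, gamma)
    rec.append(o)
print(torch.allclose(par, torch.stack(rec), atol=1e-4))  # True
```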

Also note that this GitHub project is not the official Microsoft implementation referenced from the paper, so there might be some differences in performance.

syncdoth commented 1 year ago

Thanks @jploski for explaining!

Indeed, this is an unofficial project; if you want the original code, check out the torchscale repo, which now has the official RetNet implementation. I have plans to update this implementation based on it. (Always open to contributions too!)

One clarification though: the generate function uses the recurrent method by default (parallel forward for the prompt -> obtain the KV cache -> generate with the recurrent forward). I'm not sure what the speed bottleneck is in my code, but I suspect it could be the parallel compute for the prompt. We'll have to investigate further.
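One way to narrow down where the time goes is sketched below. It reuses only the `generate` call from the snippets above; `split_timing` is a hypothetical helper, it applies to the RetNet model only (since `parallel_compute_prompt` is specific to this repo), and the split is approximate because the `max_new_tokens=1` run still includes one decode step.

```python
import time
import torch

def split_timing(model, inputs, n_new=100):
    """Rough split of prompt-processing cost vs. per-token decode cost."""
    torch.cuda.synchronize()
    t0 = time.time()
    model.generate(**inputs, parallel_compute_prompt=True, max_new_tokens=1)
    torch.cuda.synchronize()
    prompt_ms = (time.time() - t0) * 1000.0  # dominated by the prompt forward pass

    t0 = time.time()
    model.generate(**inputs, parallel_compute_prompt=True, max_new_tokens=n_new)
    torch.cuda.synchronize()
    total_ms = (time.time() - t0) * 1000.0

    print(f"prompt pass ~ {prompt_ms:.1f} ms, "
          f"decode ~ {(total_ms - prompt_ms) / (n_new - 1):.1f} ms/token")
```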

In theory, as the context gets longer, the per-token decoding time should stay constant for RetNet but grow for Transformers.
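A quick way to check that claim empirically would be a sweep over prompt lengths, something like the following sketch (the helper name and the specific lengths are made up for illustration):

```python
import time
import torch

def per_token_ms_vs_context(model, tokenizer, lengths=(256, 512, 1024, 2048), n_new=32):
    """Measure generation latency for increasingly long prompts. Roughly constant
    ms/token across lengths is RetNet-style behaviour; growing ms/token is
    Transformer-style behaviour."""
    for n_ctx in lengths:
        # Build a prompt of roughly n_ctx tokens by repeating a phrase and truncating.
        text = "Retention refers to " * n_ctx
        inputs = tokenizer(text, return_tensors="pt", truncation=True,
                           max_length=n_ctx).to("cuda")
        torch.cuda.synchronize()
        t0 = time.time()
        model.generate(**inputs, max_new_tokens=n_new)
        torch.cuda.synchronize()
        ms = (time.time() - t0) * 1000.0
        print(f"context {inputs['input_ids'].shape[1]:5d}: {ms / n_new:.1f} ms/token "
              f"(includes the prompt pass)")
```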

jploski commented 1 year ago

Thanks, I see it now. I was confused by the "parallel forward" comment in the Language Generation section of README.md and the default value "parallel" for forward_impl in the code, along with the specific examples of stepwise recurrent/chunkwise inference.