LiZeng001 closed this issue 1 year ago.
What you are testing with the "generate" method above is the "parallel" forward pass. I believe that to get O(1) per-token inference cost, you need to generate tokens with the recurrent method (or with the chunked method, which is a tradeoff between recurrent and parallel).
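To make the parallel/recurrent distinction concrete, here is a toy single-head sketch in NumPy. It assumes a heavily simplified retention with a scalar decay `gamma` and none of the rotation, normalization, or gating used in the paper or in this repo; `parallel_retention` and `recurrent_retention` are hypothetical names, not this project's API. Both forms produce the same outputs, but the recurrent form only carries a fixed-size state from one token to the next.

```python
import numpy as np


def parallel_retention(q, k, v, gamma):
    """Parallel form: O = (Q K^T * D) V, with D[n, m] = gamma**(n - m) for n >= m, else 0."""
    n = q.shape[0]
    idx = np.arange(n)
    diff = idx[:, None] - idx[None, :]  # n - m for every (query, key) pair
    decay = np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)
    return (q @ k.T * decay) @ v


def recurrent_retention(q, k, v, gamma):
    """Recurrent form: S_n = gamma * S_{n-1} + k_n^T v_n, then o_n = q_n S_n."""
    # The state is d_k x d_v, independent of how many tokens have been seen.
    state = np.zeros((k.shape[1], v.shape[1]))
    outputs = []
    for q_n, k_n, v_n in zip(q, k, v):
        state = gamma * state + np.outer(k_n, v_n)
        outputs.append(q_n @ state)
    return np.stack(outputs)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
print(np.allclose(parallel_retention(q, k, v, 0.9), recurrent_retention(q, k, v, 0.9)))  # True
```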
Also note that this GitHub project is not the official Microsoft implementation referenced in the paper, so there may be some differences in performance.
Thanks @jploski for explaining!
Indeed, this is an unofficial project. If you want the original code, check out the torchscale repo, which now has the official code for RetNet. I plan to update this implementation based on that (always open to contributions, too!).
One clarification, though: the generate function uses the recurrent method by default (parallel for the prompt -> obtain the KV cache -> generate using recurrent). I'm not sure what the speed bottleneck in my code is, but I suspect it could be the parallel computation for the prompt part. We'll have to investigate further.
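The "parallel for the prompt, then recurrent" pipeline can be sketched in the same toy setting. This assumes a single head with a scalar decay and no rotation or normalization, and uses random `q`, `k`, `v` as stand-ins for the projections of real tokens; `prefill` and `decode_step` are hypothetical names, not functions from this repo. The prompt is processed in one parallel pass that also returns the recurrent state (RetNet's counterpart of a KV cache), and each later token is a constant-cost recurrent step.

```python
import numpy as np


def decay_matrix(n, gamma):
    """D[n, m] = gamma**(n - m) for n >= m, else 0."""
    idx = np.arange(n)
    diff = idx[:, None] - idx[None, :]
    return np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)


def prefill(q, k, v, gamma):
    """Run the prompt in parallel form; return its outputs and the state after its last token."""
    n = q.shape[0]
    outputs = (q @ k.T * decay_matrix(n, gamma)) @ v
    # S = sum_m gamma**(n - 1 - m) * k_m^T v_m  (older tokens are decayed more)
    weights = gamma ** np.arange(n - 1, -1, -1)
    state = k.T @ (v * weights[:, None])
    return outputs, state


def decode_step(q_t, k_t, v_t, state, gamma):
    """One recurrent step: O(d^2) work no matter how long the context already is."""
    state = gamma * state + np.outer(k_t, v_t)
    return q_t @ state, state


rng = np.random.default_rng(1)
gamma, prompt_len, total_len, d = 0.9, 5, 8, 4
q, k, v = (rng.standard_normal((total_len, d)) for _ in range(3))

prompt_out, state = prefill(q[:prompt_len], k[:prompt_len], v[:prompt_len], gamma)
decoded = []
for t in range(prompt_len, total_len):
    o_t, state = decode_step(q[t], k[t], v[t], state, gamma)
    decoded.append(o_t)

# Same result as running the whole sequence in parallel form.
full = (q @ k.T * decay_matrix(total_len, gamma)) @ v
print(np.allclose(np.vstack([prompt_out, np.stack(decoded)]), full))  # True
```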
In theory, as the context grows longer, the per-token decoding time should stay constant for RetNet but keep increasing for Transformers.
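A rough back-of-the-envelope count shows where that difference comes from. This sketch only counts multiply-accumulates in the token-mixing part of a single head (projections and feed-forward layers are the same for both models and are left out), ignores the elementwise decay, and uses a hypothetical head dimension; the function names are illustrative.

```python
def attention_decode_macs(context_len: int, d: int) -> int:
    """Approximate multiply-accumulates for one softmax-attention decode step with a KV cache:
    q . K^T against context_len cached keys, then a weighted sum of context_len cached values."""
    return 2 * context_len * d


def retention_decode_macs(d: int) -> int:
    """Approximate multiply-accumulates for one recurrent retention step:
    add the outer product k^T v into the d x d state, then compute q @ state."""
    return 2 * d * d


d = 128  # hypothetical per-head dimension
for context_len in (128, 1024, 8192):
    print(context_len, attention_decode_macs(context_len, d), retention_decode_macs(d))
```

Under this rough count the two costs are equal when `context_len == d`, and attention only becomes more expensive per step beyond that point, so a short run (a few prompt tokens plus 100 new tokens) says little about the long-context per-token behavior.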
Thanks, I see it now. I was confused by the "parallel forward" comment in the Language Generation section of README.md, the default value "parallel" for forward_impl in the code, and the separate examples of stepwise recurrent/chunked inference.
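For reference, the chunked mode can also be sketched in the simplified single-head setting (scalar decay, no rotation or normalization; `chunkwise_retention` is a hypothetical name, not this repo's API). Tokens inside a chunk are mixed in parallel form, and information from earlier chunks arrives through the same fixed-size state used by the recurrent form.

```python
import numpy as np


def decay_matrix(n, gamma):
    """D[n, m] = gamma**(n - m) for n >= m, else 0."""
    idx = np.arange(n)
    diff = idx[:, None] - idx[None, :]
    return np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)


def chunkwise_retention(q, k, v, gamma, chunk_size):
    """Parallel inside each chunk, recurrent across chunks via a d_k x d_v state."""
    state = np.zeros((k.shape[1], v.shape[1]))
    outputs = []
    for start in range(0, q.shape[0], chunk_size):
        end = start + chunk_size
        qc, kc, vc = q[start:end], k[start:end], v[start:end]
        b = qc.shape[0]  # the last chunk may be shorter than chunk_size
        # Contribution of tokens inside this chunk (parallel form).
        inner = (qc @ kc.T * decay_matrix(b, gamma)) @ vc
        # Contribution of all earlier chunks: position j in the chunk sees the state decayed by gamma**(j + 1).
        cross = (gamma ** np.arange(1, b + 1))[:, None] * (qc @ state)
        outputs.append(inner + cross)
        # Advance the state to the end of this chunk.
        state = gamma ** b * state + kc.T @ (vc * (gamma ** np.arange(b - 1, -1, -1))[:, None])
    return np.vstack(outputs)


rng = np.random.default_rng(2)
q, k, v = (rng.standard_normal((10, 4)) for _ in range(3))
full = (q @ k.T * decay_matrix(10, 0.9)) @ v
print(np.allclose(chunkwise_retention(q, k, v, 0.9, chunk_size=4), full))  # True
```

Chunk size is the knob between the two extremes: `chunk_size=1` reduces to the recurrent form, and a chunk as long as the sequence is the parallel form.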
Hi, thank you for your great work!
When I use your example code to compare inference latency against a Transformer-based LLM, the result does not match the paper (15.6X). Could you please help?
I ran a generation task of the same length on RetNet-6.7B and on a similarly sized model (chinese_llama_7b). Below are the code and results.
(1) RetNet-6.7B
```python
import time

import torch
from transformers import AutoTokenizer

from retnet.configuration_retnet import load_config_from_yaml
from retnet.modeling_retnet import RetNetModelWithLMHead

config = load_config_from_yaml('configs/retnet-6.7b.yml')
# Built from the config with randomly initialized weights (no checkpoint is loaded).
model = RetNetModelWithLMHead(config).to("cuda")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.model_max_length = 2048
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer("Retention refers to", return_tensors='pt').to("cuda")

torch.cuda.synchronize()  # make sure no earlier GPU work is still queued before starting the clock
start_time = time.time()
generated = model.generate(
    **inputs,
    parallel_compute_prompt=True,
    max_new_tokens=100,
    eos_token_id=tokenizer.encode("\n")[0],
)
torch.cuda.synchronize()  # wait for all generation kernels to finish
end_time = time.time()

print("time cost:", int(float(end_time - start_time) * 1000.0), "ms")
print(generated.shape)
print(tokenizer.batch_decode(generated))
```
RetNet-6.7B latency:
time cost: 4929 ms
torch.Size([1, 100])
['Japan Muslimcue Developers absence finals rundownutenberg agre Reefaqurr certifyprice swap Lupidity431 acutelyprice Monthly Loading heftytri tengpnatdemocratic Rating Dayamping unc Future Townsenv ArnoldJP Crusader despite axapple Warriors else Borderlands collaborativeios mounted lingerusp Greene Celtics PHOTO McCainagusReallyParam < limitations detect fixesundrum isot towering there proclaimed503 CM white 1979 schizophren teamongo PearsonMurray shadow NEXT activated GunogenouspriceIndust pedestrian scattering locnaiamelessintelligence resorted SenegalHall Keystoneroute wrap850Bernie divert Sadd witch Completed Sly']
========================================================================
(2) chinese_llama_7b
```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/pretrain-models/llama/chinese_llama_lora_7b_merge"
gpu_available = torch.cuda.is_available()
print("gpu: ", gpu_available)
model = AutoModelForCausalLM.from_pretrained(model_path).to("cuda")
# NOTE: this is the GPT-2 tokenizer, not the tokenizer shipped with the LLaMA checkpoint,
# so the decoded text is not meaningful; only the timing is being compared.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.model_max_length = 2048

inputs = tokenizer("Retention refers to", return_tensors="pt").to("cuda")

torch.cuda.synchronize()  # start timing with an idle GPU
start_time = time.time()
tokens = model.generate(
    **inputs,
    temperature=0.9,  # only used when sampling (do_sample=True)
    max_new_tokens=100,
    eos_token_id=tokenizer.encode("\n")[0],
)
torch.cuda.synchronize()  # wait for all generation kernels to finish
end_time = time.time()

print("time cost:", int(float(end_time - start_time) * 1000.0), "ms")
print(tokens.shape)
print(tokenizer.batch_decode(tokens))
```
chinese_llama_7b latency:
time cost: 2280 ms
torch.Size([1, 104])
['Retention refers to mortar invade diarrFri Lak Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated" mortar Atmosp guiActiveUnfocused CEOs402whel coated"']
=======================================================================
As a result, to generate text of the same length from the same input, RetNet-6.7B took 4929 ms while chinese_llama_7b took 2280 ms, which is not consistent with the paper. Is my test method wrong, or is it something else? Any help would be greatly appreciated!
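On the measurement side, GPU wall-clock numbers can be skewed by CUDA's asynchronous execution and by one-off warm-up costs on the first call. Below is a small, generic timing helper, sketched under the assumption that you pass something like `torch.cuda.synchronize` as `sync`; `time_per_call` is a hypothetical name, not part of this repo or of PyTorch.

```python
import time
from typing import Callable, List, Optional


def time_per_call(step: Callable[[], object], n_calls: int, warmup: int = 3,
                  sync: Optional[Callable[[], None]] = None) -> List[float]:
    """Return wall-clock milliseconds for each of n_calls calls to step().

    `sync` should block until queued GPU work has finished (for example
    torch.cuda.synchronize); without it, asynchronous kernels can make the
    measured time misleading.
    """
    # Warm-up calls absorb one-off costs such as CUDA context setup and kernel selection.
    for _ in range(warmup):
        step()
    if sync:
        sync()
    timings = []
    for _ in range(n_calls):
        start = time.perf_counter()
        step()
        if sync:
            sync()
        timings.append((time.perf_counter() - start) * 1000.0)
    return timings
```

For example, `time_per_call(lambda: model.generate(**inputs, max_new_tokens=100), n_calls=5, sync=torch.cuda.synchronize)` would time the full `generate` call from the snippets in this issue. Since the claim in this thread is about per-token cost as the context grows, repeating the measurement at several prompt and generation lengths is likely more informative than a single 100-token run.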