microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

Is there some example of the paper? e.g., a comparison of the inference latency #53

Closed LiZeng001 closed 1 year ago

LiZeng001 commented 1 year ago

Hi, thank you for your great work! We are interested in the capabilities of RetNet. However, when we look through this repository, we can't find the code corresponding to the paper's experiments: for example, generating text of the same length with RetNet and a Transformer-based LLM of similar size (such as Llama-7b) to compare inference latency, an example of long-sequence inference, and so on.

So, can you provide some basic example code for training/inference that compares RetNet with a Transformer-based LLM, without Fairseq?

Any response would be greatly helpful to us!
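
For concreteness, a minimal wall-clock harness for the kind of comparison described above might look like the sketch below. This is not code from the repository; generate_fn is a placeholder for whatever greedy decoding loop each model under test uses.

import time
import torch

@torch.no_grad()
def seconds_per_token(generate_fn, prompt, new_tokens=128, warmup=1):
    """Average wall-clock seconds per generated token for one decoding run.

    generate_fn is a placeholder: any callable that decodes `new_tokens`
    tokens from `prompt` with the model under test.
    """
    for _ in range(warmup):
        generate_fn(prompt, new_tokens)      # warm up kernels and allocator
    if torch.cuda.is_available():
        torch.cuda.synchronize()             # exclude queued GPU work from the timed run
    start = time.perf_counter()
    generate_fn(prompt, new_tokens)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / new_tokens

Calling this once per model with the same prompt and token budget gives a like-for-like latency number.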

sunyt32 commented 1 year ago

Yeah, the experiments are based on our private data and pipelines, which are not appropriate for open-sourcing. Following the guidelines in our README, RetNet is easy to integrate into your own training procedure without Fairseq (the import method is identical to Transformer's). For inference speed, our experiments use incremental_state for auto-regressive decoding. Here is a pseudocode example:

incremental_state = {}
# Token ids; bsz, tgt_len, vocab_size, and model are assumed to be defined elsewhere.
net_input = torch.randint(0, vocab_size, (bsz, tgt_len))
for index in range(net_input.shape[1] - 1):
    generation_net_input = net_input[:, :(index + 1)]
    # incremental_state caches the past state, so only the newest position is computed.
    generation_net_output, _ = model(generation_net_input, incremental_state=incremental_state)
    # Greedily pick the next token from the logits at the last position.
    net_input[:, index + 1] = torch.argmax(generation_net_output[:, -1], dim=-1)

In every step, incremental_state stores the past state (the k/v cache for the Transformer, the recurrent kv state for RetNet). You can adapt it to your own inference pipeline.
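
To build the two models without Fairseq, the import pattern from the README can be followed directly. The sketch below uses placeholder config values and only constructs the decoders and compares their parameter counts, as a sanity check that the two architectures are of similar size before timing them with the loop above; details such as whether an embed_tokens module must be passed for a forward call should be checked against the current torchscale source.

import torch
from torchscale.architecture.config import DecoderConfig, RetNetConfig
from torchscale.architecture.decoder import Decoder
from torchscale.architecture.retnet import RetNetDecoder

def param_count(model: torch.nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())

# Illustrative configs; a fair latency comparison should match depth, width, and vocab size.
transformer = Decoder(DecoderConfig(vocab_size=64000))
retnet = RetNetDecoder(RetNetConfig(vocab_size=64000))

print(f"Transformer params: {param_count(transformer):,}")
print(f"RetNet params:      {param_count(retnet):,}")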