Hi @fkodom,

Thank you so much for sharing this work with the research community.

I have one question, please: I measure throughput at inference, and it seems that the parallel method has higher throughput than the recurrent method, which is inconsistent with the paper. They claim that recurrent inference is $O(1)$ and has higher throughput. Have you tested that, or do you know what the reason is?

Best regards,
Abdelrahman.
Hi @Amshaker!
I haven't tested exactly what you're asking about, but I do something very similar in `scripts/benchmark_inference.py`. In that script, I reproduce (a scaled-down version of) the inference benchmarks from the paper, comparing `RetNet` and `nn.Transformer` at various input sequence lengths.
> They claim that recurrent inference is $O(1)$ and has higher throughput.
This is true. Recurrent inference has (roughly) $O(1)$ complexity per token generated, although I've seen quite a few people confuse that with $O(1)$ complexity for the entire sequence. 😅 See the figure at the bottom of the README, which was generated with the script above.
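To spell that out (a back-of-the-envelope tally of my own, using the per-pass costs discussed further down): the per-token costs accumulate, so generating $N$ tokens end to end is roughly

$$
\underbrace{\sum_{n=1}^{N} O(1)}_{\text{recurrent}} = O(N)
\qquad \text{vs.} \qquad
\underbrace{\sum_{n=1}^{N} O(n^2)}_{\text{parallel, no cache}} = O(N^3)
$$

i.e. the recurrent formulation is $O(1)$ per token, but still $O(N)$ for the whole sequence.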
> I measure throughput at inference, and it seems that the parallel method has higher throughput than the recurrent method
Something to keep in mind: a single forward pass with `RetNet.forward_parallel` or `RetNet.forward_recurrent` only generates one new token. For example, with a sequence length of 4096, the `RetNet.forward_parallel` method does not generate 4096 new tokens at once. You have to pass all $N - 1$ preceding tokens into the model in order to generate token $N$. Contrast that with `RetNet.forward_recurrent`, where the internal states are cached and reused for each successive token. Maybe that's where the throughput calculation is confusing?
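To make that concrete, here's a minimal sketch of a single decoding step with `forward_parallel` (hypothetical greedy decoding, and it assumes `forward_parallel` returns per-position logits over the vocabulary -- not code from the repo):

```python
import torch

from yet_another_retnet.retnet import RetNet

# Hypothetical setup for illustration only.
retnet = RetNet(num_tokens=10000, d_model=512, nhead=8, num_layers=6).eval()
x = torch.randint(0, 10000, size=(1, 1024))  # the N - 1 tokens generated so far

with torch.no_grad():
    # Re-process the entire prefix just to predict one new token.
    logits = retnet.forward_parallel(x)      # assumed: (batch, N - 1, num_tokens)
    token_n = logits[:, -1].argmax(dim=-1)   # only the last position is new
    x = torch.cat([x, token_n.unsqueeze(-1)], dim=-1)  # the prefix grows by one
```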
Here's a script that (almost) measures what you're after. In the paper, they use a KV cache when measuring parallel Transformer throughput, which speeds things up quite a bit. I didn't bother with a KV cache for RetNet, because the recurrent formulation is already significantly faster.
```python
import torch

from yet_another_retnet.retnet import RetNet
from yet_another_retnet.utils.benchmark import benchmark

num_tokens = 10000  # vocab size
batch_size = 4
seq_length = 2048
d_model = 512
nhead = 8
num_layers = 6
device = "cuda"
dtype = torch.float16

x = torch.randint(0, num_tokens, size=(batch_size, seq_length), device=device)
retnet = RetNet(
    num_tokens=num_tokens,
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    device=device,
    dtype=dtype,
).eval()

with torch.no_grad():
    # One recurrent step -> one new token per batch element.
    recurrent = benchmark(retnet.forward_recurrent, x[:, 0], seq_idx=0)
    print(f"RetNet recurrent: {batch_size / recurrent.mean:.3f} tokens/s")
    # One parallel pass over the full sequence -> still one new token per batch element.
    parallel = benchmark(retnet.forward_parallel, x)
    print(f"RetNet parallel: {batch_size / parallel.mean:.3f} tokens/s")

# Result on my machine with a 2080 Ti GPU:
# RetNet recurrent: 484.416 tokens/s
# RetNet parallel: 71.981 tokens/s
```
Thank you so much @fkodom for your prompt and comprehensive reply; I really appreciate it!
Are you sure that `RetNet.forward_parallel` generates one token? It seems that the whole sequence is generated by the parallel method. Yes, `RetNet.forward_recurrent` generates only one token.

Even in the README, you apply a for loop only to the recurrent method to get the whole sequence; for the parallel method, there is no for loop.
I am measuring the throughput and time for the whole sequence. Now I am trying to measure the time of `MultiScaleRetention`'s `forward_parallel` and `forward_recurrent`. According to my tracing, `RetNet.forward_parallel` generates the whole sequence. Hence, I am measuring the time this way:
```python
import time

import torch

from yet_another_retnet.retention import MultiScaleRetention

mhr = MultiScaleRetention(embed_dim=32, num_heads=4, device="cuda").eval()

# input shape: (batch_size, seq_len, embed_dim)
q = k = v = torch.randn(2048, 4, 32, device="cuda")

# Parallel retention -- one pass over the full sequence
start_time = time.time()
y_parallel, _ = mhr.forward_parallel(q, k, v)
end_time = time.time()
execution_time = end_time - start_time
print(f"Execution time parallel: {execution_time} seconds")

# Recurrent retention -- one step per sequence position
outputs = []
prev_state = None
start_time = time.time()
for idx in range(4):
    out, prev_state = mhr.forward_recurrent(
        q[:, idx], k[:, idx], v[:, idx], idx, prev_state
    )
    outputs.append(out)
y_recursive = torch.stack(outputs, dim=1)
end_time = time.time()
execution_time = end_time - start_time
print(f"Execution time recurrent: {execution_time} seconds")

# Check that outputs are equal
torch.testing.assert_close(y_parallel, y_recursive)
```
The output is:

```
Execution time parallel: 0.013367414474487305 seconds
Execution time recurrent: 0.007351875305175781 seconds
```
When I increase the sequence length from 4 to 48, the output is:

```
Execution time parallel: 0.09894466400146484 seconds
Execution time recurrent: 0.15374445915222168 seconds
```
So the recurrent method is faster only at smaller sequence lengths. Please correct me if I am wrong. The whole point is to measure the time/throughput for the whole sequence, not for each token, and to confirm whether `forward_parallel` generates a single token or the whole sequence.
@Amshaker You're not entirely wrong. 🙃 To my understanding, here is the difference from what you described.
Suppose you're generating text with a RetNet model. So far, you've generated 1024 tokens. You do not know what the 1025th token is until you compute it with RetNet.
In order to compute token 1025, you could:
1. Use `RetNet.forward_parallel`. The model is trained with labels offset by 1 index, so the last token returned is your predicted 1025th token.
2. Use `RetNet.forward_recurrent` while keeping track of the state Tensor. When passing the 1024th token, the model returns a prediction for token 1025.

Then, you go to compute token 1026.
1. There is no state to reuse with `RetNet.forward_parallel`, so you're forced to pass the entire 1025-token sequence (or up to the allowed context window) into the model. The last token returned is your predicted 1026th token.
2. With `RetNet.forward_recurrent`, you do have a state Tensor to reuse. So, like before, you pass the new token and the state into the model and get a prediction for token 1026.

^^ Scenario (1) is how Transformer-based inference typically goes. That's why each forward pass only generates one token -- not 1024 at once. There are caching tricks to make it more efficient, which reduce the cost of each forward pass from $O(N^2)$ (naive multi-head attention over the full sequence) to $O(N)$. Because the recurrent `RetNet` formulation maintains a running state Tensor, inference for each new token doesn't depend on the preceding sequence length, so it is $O(1)$.
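For illustration, here's a rough sketch of those two decoding loops (hypothetical greedy decoding, not code from this repo; the `forward_recurrent` signature and return values are assumed to mirror the `MultiScaleRetention` example above):

```python
import torch

from yet_another_retnet.retnet import RetNet

# Assumed interfaces (hypothetical):
#   retnet.forward_parallel(tokens) -> per-position logits
#   retnet.forward_recurrent(token, seq_idx, prev_states) -> (logits, states)


@torch.no_grad()
def generate_parallel(retnet: RetNet, prompt: torch.Tensor, num_new: int) -> torch.Tensor:
    tokens = prompt  # (batch_size, prompt_len)
    for _ in range(num_new):
        # Scenario (1): re-process the entire prefix to predict one new token.
        logits = retnet.forward_parallel(tokens)
        next_token = logits[:, -1].argmax(dim=-1)
        tokens = torch.cat([tokens, next_token.unsqueeze(-1)], dim=-1)
    return tokens


@torch.no_grad()
def generate_recurrent(retnet: RetNet, prompt: torch.Tensor, num_new: int) -> torch.Tensor:
    tokens, states = prompt, None
    # Build up the running state by feeding the prompt one token at a time.
    for idx in range(prompt.size(1)):
        logits, states = retnet.forward_recurrent(prompt[:, idx], idx, states)
    for idx in range(num_new):
        # Scenario (2): O(1) work per new token, independent of prefix length.
        next_token = logits.argmax(dim=-1)
        tokens = torch.cat([tokens, next_token.unsqueeze(-1)], dim=-1)
        logits, states = retnet.forward_recurrent(
            next_token, prompt.size(1) + idx, states
        )
    return tokens
```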