fkodom / yet-another-retnet

A simple but robust PyTorch implementation of RetNet from "Retentive Network: A Successor to Transformer for Large Language Models" (https://arxiv.org/pdf/2307.08621.pdf)
MIT License

Throughput measurements of parallel and recurrence methods #1

Closed: Amshaker closed this issue 1 year ago

Amshaker commented 1 year ago

Hi @fkodom,

Thank you so much for sharing this work with the research community.

I have one question, please: I measure the throughput in the inference and it seems that the parallel method has more throughput compared to the recurrence method, which is inconsistent with the paper. They claim that recurrent inference is O(1) and has higher throughput. Have you tested that, or do you know the reason?

Best regards, Abdelrahman.

fkodom commented 1 year ago

Hi @Amshaker!

I haven't tested exactly what you're asking about, but I have tested something very similar in scripts/benchmark_inference.py. In that script, I reproduce (a scaled-down version of) the inference benchmarks in the paper, comparing RetNet and nn.Transformer at various input sequence lengths.

They claim that recurrent inference is O(1) and has higher throughput.

This is true. Recurrent inference has (roughly) $O(1)$ complexity per generated token. That said, I've seen quite a few people confuse that with $O(1)$ complexity for the entire sequence. 😅 See this figure (bottom of README), which was generated with that script.
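A rough way to see the difference, assuming (for illustration only) that one recurrent step costs a constant $c$ regardless of position, and that one parallel forward pass over $t$ tokens costs on the order of $t^2$ because of the $t \times t$ retention matrix (ignoring constant factors and the model width):

$$\text{recurrent, } N \text{ tokens: } \sum_{t=1}^{N} c = cN = O(N) \quad (\text{i.e. } O(1) \text{ per token})$$

$$\text{parallel, one pass over } N \text{ tokens: } O(N^2)$$

$$\text{parallel, re-run for every generated token: } \sum_{t=1}^{N} t^2 = \frac{N(N+1)(2N+1)}{6} = O(N^3)$$

So $O(1)$ describes the cost of each new token, not of the whole sequence.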

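[Figure: inference benchmark plot from the bottom of the README, generated with scripts/benchmark_inference.py (screenshot dated 2023-08-08)]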

I measure the throughput in the inference and it seems that the parallel method has more throughput compared to the recurrence method

Something to keep in mind: in autoregressive generation, a single forward pass with RetNet.forward_parallel or RetNet.forward_recurrent only generates one new token. For example, with a sequence length of 4096, RetNet.forward_parallel does not generate 4096 new tokens at once. You have to pass all N-1 preceding tokens into the model in order to generate token N. Contrast that with RetNet.forward_recurrent, where the internal states are cached and reused for each successive token. Maybe that's where the throughput calculation is getting confusing?
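To make the cached-state idea concrete, here is a minimal single-head sketch of retention in its parallel and recurrent forms. It is a simplification under stated assumptions, not this repo's MultiScaleRetention: it uses one fixed decay factor `gamma` and omits the rotary/xPos position encoding, group norm, and gating. The parallel form builds a full seq_len × seq_len decay-masked score matrix, while the recurrent form only carries a (d_k, d_v) state from token to token.

```python
import torch


def retention_parallel(q, k, v, gamma):
    """Parallel form: (Q K^T * D) V with D[n, m] = gamma**(n - m) for m <= n, else 0.

    q, k: (seq_len, d_k); v: (seq_len, d_v).
    """
    idx = torch.arange(q.shape[0])
    distance = (idx[:, None] - idx[None, :]).to(q.dtype)
    # Clamp so the upper triangle stays finite, then zero it out (causal mask).
    decay = torch.tril(gamma ** distance.clamp(min=0))
    return (q @ k.T * decay) @ v


def retention_recurrent_step(q_n, k_n, v_n, state, gamma):
    """Recurrent form: S_n = gamma * S_{n-1} + k_n^T v_n, output_n = q_n S_n.

    q_n, k_n: (d_k,); v_n: (d_v,); state: (d_k, d_v), the same size at every step.
    """
    state = gamma * state + torch.outer(k_n, v_n)
    return q_n @ state, state


torch.manual_seed(0)
seq_len, d_k, d_v, gamma = 6, 4, 3, 0.9
q, k, v = torch.randn(seq_len, d_k), torch.randn(seq_len, d_k), torch.randn(seq_len, d_v)

y_parallel = retention_parallel(q, k, v, gamma)

# Feed one token at a time, reusing the running state.
state = torch.zeros(d_k, d_v)
outputs = []
for n in range(seq_len):
    out, state = retention_recurrent_step(q[n], k[n], v[n], state, gamma)
    outputs.append(out)
y_recurrent = torch.stack(outputs)

torch.testing.assert_close(y_parallel, y_recurrent)
print(state.shape)  # torch.Size([4, 3]), independent of seq_len
```

The final `state` has the same shape no matter how long the sequence is, which is why each additional recurrent step costs the same amount of work.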

Here's a script that (almost) measures what you're after. In the paper, they use a KV cache when measuring parallel Transformer throughput, which speeds things up quite a bit. I didn't bother with a KV cache for RetNet because the recurrent formulation is already significantly faster.

import torch

from yet_another_retnet.retnet import RetNet
from yet_another_retnet.utils.benchmark import benchmark

num_tokens = 10000  # vocab size
batch_size = 4
seq_length = 2048
d_model = 512
nhead = 8
num_layers = 6

device = "cuda"
dtype = torch.float16

x = torch.randint(0, num_tokens, size=(batch_size, seq_length), device=device)
retnet = RetNet(
    num_tokens=num_tokens,
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    device=device,
    dtype=dtype,
).eval()

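with torch.no_grad():
    # One recurrent step: each of the batch_size sequences yields one new token
    # (seq_idx=0 and no previous states are passed, so there is nothing to reuse yet).
    recurrent = benchmark(retnet.forward_recurrent, x[:, 0], seq_idx=0)
    print(f"RetNet recurrent: {batch_size / recurrent.mean:.3f} tokens/s")
    # One parallel pass over all seq_length tokens, which also yields one new
    # token per sequence (the prediction at the last position).
    parallel = benchmark(retnet.forward_parallel, x)
    print(f"RetNet parallel: {batch_size / parallel.mean:.3f} tokens/s")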

# Result on my machine with 2080 Ti GPU
# RetNet recurrent: 484.416 tokens/s
# RetNet parallel: 71.981 tokens/s
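The script above skips a KV cache, but the paper's Transformer baseline uses one. For context, here is a minimal single-head sketch of what a KV cache does (plain causal softmax attention, no batching; the `cache` dict is a hypothetical structure, not code from this repo). Each step appends the new key and value and attends over everything cached so far, so per-token work and memory grow with the number of cached tokens, unlike the fixed-size RetNet recurrent state.

```python
import math

import torch


def causal_attention(q, k, v):
    """Full (parallel) causal softmax attention. q, k: (n, d); v: (n, d_v)."""
    n, d = q.shape
    scores = q @ k.T / math.sqrt(d)
    # Block attention to future positions (column index > row index).
    future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


def attention_step_with_kv_cache(q_t, k_t, v_t, cache):
    """One decoding step. q_t, k_t: (d,); v_t: (d_v,); cache: {"k": [...], "v": [...]}."""
    cache["k"].append(k_t)
    cache["v"].append(v_t)
    keys = torch.stack(cache["k"])  # (t, d): grows by one row per step
    values = torch.stack(cache["v"])  # (t, d_v)
    scores = keys @ q_t / math.sqrt(q_t.shape[-1])  # (t,)
    return torch.softmax(scores, dim=0) @ values  # (d_v,)


torch.manual_seed(0)
n, d, d_v = 5, 8, 4
q, k, v = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d_v)

cache = {"k": [], "v": []}
steps = torch.stack(
    [attention_step_with_kv_cache(q[t], k[t], v[t], cache) for t in range(n)]
)
# Step-by-step decoding with the cache matches one full causal pass.
torch.testing.assert_close(steps, causal_attention(q, k, v))
print(len(cache["k"]))  # 5: one cached key per processed token
```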
Amshaker commented 1 year ago

Thank you so much, @fkodom, for your prompt and comprehensive reply. I really appreciate it!

Are you sure that RetNet.forward_parallel generates one token? It seems that the whole sequence is generated by the parallel method. I agree that RetNet.forward_recurrent generates only one token.

Even in the README here, you use a for loop only for the recurrent method to get the whole sequence, while the parallel method has no for loop.

[Image: screenshot of the README example]

I am measuring the throughput and time for the whole sequence. Now I am trying to time MultiScaleRetention.forward_parallel and MultiScaleRetention.forward_recurrent directly. According to my tracing, RetNet.forward_parallel generates the whole sequence, so I am measuring the time like this:

import torch
import time
from yet_another_retnet.retention import MultiScaleRetention

mhr = MultiScaleRetention(embed_dim=32, num_heads=4, device="cuda").eval()

# input shape: (batch_size, seq_len, embed_dim)
q = k = v = torch.randn(2048, 4, 32, device="cuda")

# Parallel retention
start_time = time.time()
y_parallel, _ = mhr.forward_parallel(q, k, v)
end_time = time.time()
execution_time = end_time - start_time
print(f"Execution time parallel: {execution_time} seconds")

# Recurrent retention: one token per step, carrying prev_state forward
outputs = []
prev_state = None
start_time = time.time()
for idx in range(q.shape[1]):  # seq_len
    out, prev_state = mhr.forward_recurrent(
        q[:, idx], k[:, idx], v[:, idx], idx, prev_state
    )
    outputs.append(out)
y_recursive = torch.stack(outputs, dim=1)
end_time = time.time()

execution_time = end_time - start_time
print(f"Execution time recurrent: {execution_time} seconds")

# Check that outputs are equal
torch.testing.assert_close(y_parallel, y_recursive)

The output is:

Execution time parallel: 0.013367414474487305 seconds
Execution time recurrent: 0.007351875305175781 seconds

When I increase the sequence length from 4 to 48, the output is:

Execution time parallel: 0.09894466400146484 seconds
Execution time recurrent: 0.15374445915222168 seconds

So the recurrent method is faster only at the smaller sequence length. Please correct me if I am wrong. My goal is to measure the time/throughput for the whole sequence, not for each token, and to confirm whether forward_parallel generates a single token or the whole sequence.
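One caveat when reading these numbers: CUDA kernels launch asynchronously, so `time.time()` around a single call may stop the clock before the GPU has finished, and whichever method runs first also pays one-time warm-up costs. A generic timing helper along these lines (a sketch; `time_fn` is a hypothetical name, not the repo's `benchmark` utility) warms up, synchronizes the device, and averages over several calls:

```python
import time

import torch


def time_fn(fn, *args, warmup=3, iters=20, **kwargs):
    """Return the mean wall-clock seconds per call of fn(*args, **kwargs)."""
    use_cuda = torch.cuda.is_available()
    with torch.no_grad():
        # Warm-up calls absorb one-time costs such as CUDA context setup or memory allocation.
        for _ in range(warmup):
            fn(*args, **kwargs)
        if use_cuda:
            torch.cuda.synchronize()  # make sure warm-up work is done before timing
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args, **kwargs)
        if use_cuda:
            torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
        return (time.perf_counter() - start) / iters


a, b = torch.randn(256, 256), torch.randn(256, 256)
print(f"{time_fn(torch.matmul, a, b) * 1e3:.3f} ms per matmul")
```

To compare whole sequences, the recurrent `for` loop from the snippet above can be wrapped in a function and passed as `fn`, so both methods are timed end to end over the same sequence length.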

fkodom commented 1 year ago

@Amshaker You're not entirely wrong. 🙃 To my understanding, here is where it differs from what you described.

Suppose you're generating text with a RetNet model. So far, you've generated 1024 tokens. You do not know what the 1025th token is until you compute it with RetNet.

In order to compute token 1025, you could:

  1. Pass the entire 1024-token sequence into RetNet.forward_parallel. The model is trained with labels offset by 1 index, so the output at the last position is your predicted 1025th token.
  2. Incrementally pass the first 1023 tokens into RetNet.forward_recurrent while keeping track of the state Tensor. When you pass the 1024th token, the model returns a prediction for token 1025.

Then, you go to compute token 1026.

  1. You do not have a state Tensor to reuse with RetNet.forward_parallel, so you're forced to pass the entire 1025-token sequence (or as much of it as fits in the allowed context window) into the model. The output at the last position is your predicted 1026th token.
  2. With RetNet.forward_recurrent, you do have a state Tensor to reuse. So, as before, you pass that state together with token 1025 into the model and get a prediction for token 1026.

Scenario (1) above is how Transformer-based inference typically goes. That's why each forward pass only generates one token, not 1024 at once. There are caching tricks (e.g. a KV cache) that make it more efficient, reducing the cost per generated token from $O(N^2)$ (naive multi-head attention over the whole sequence) to $O(N)$. Because the recurrent RetNet formulation maintains a running state Tensor, the cost of each new token doesn't depend on the preceding sequence length, so it is $O(1)$ per token.
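The bookkeeping in the two scenarios above can be counted directly. This toy function (hypothetical, and assuming no caching of any kind on the parallel path) counts how many token positions are fed into the model to generate `new_tokens` tokens after a `prompt_len`-token prompt:

```python
def positions_processed(prompt_len: int, new_tokens: int) -> tuple[int, int]:
    """Return (parallel, recurrent) counts of token positions fed to the model.

    Parallel, no cache: generating the i-th new token re-feeds the whole
    prefix of prompt_len + i tokens.
    Recurrent: every prompt token is fed once, then every generated token
    except the last is fed once to produce the next prediction.
    """
    if new_tokens < 1:
        raise ValueError("new_tokens must be at least 1")
    parallel = sum(prompt_len + i for i in range(new_tokens))
    recurrent = prompt_len + (new_tokens - 1)
    return parallel, recurrent


# Tokens 1025 and 1026 after a 1024-token prefix, as in the example above.
print(positions_processed(1024, 2))  # (2049, 1025)
```

The gap widens quadratically with `new_tokens`, and it understates the difference in compute, since each parallel pass over t tokens also builds a t × t score matrix while each recurrent step only updates a fixed-size state.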