idiap / fast-transformers

PyTorch library for fast transformer implementations

Speed of recurrent model #116

Closed · mads-oestergaard closed this issue 2 years ago

mads-oestergaard commented 2 years ago

Hi,

I'm playing around with fast-transformers for use in the audio domain, and wanted to train a regular model with causal-linear attention and then evaluate it as a recurrent model. However, this quick inference speed test struck me as odd:

from fast_transformers.builders import TransformerEncoderBuilder, \
    RecurrentEncoderBuilder
from fast_transformers.masking import TriangularCausalMask
from fast_transformers.utils import make_mirror

import time
import torch

d_model = 128
n_layers = 8
n_heads = 8

params = dict(attention_type="causal-linear",
    n_layers=n_layers,
    n_heads=n_heads,
    feed_forward_dimensions=512,
    query_dimensions=d_model // n_heads,
    value_dimensions=d_model // n_heads,
    )

# Build the models
model = TransformerEncoderBuilder.from_kwargs(**params).get()
recurrent_model = RecurrentEncoderBuilder.from_kwargs(**params).get()
make_mirror(model, recurrent_model)

x_in = torch.randn(2, 2000, 128)
mask = TriangularCausalMask(x_in.shape[1], device=x_in.device)

# Time the parallel model on the CPU
t0 = time.time()
with torch.no_grad():
    x_mask = model(x_in, attn_mask=mask)
elapsed = time.time() - t0
print("Elapsed:", elapsed*1000, "ms")
# >> Elapsed: 162.17 ms

# Time the recurrent model on the CPU
recurrent_output = torch.zeros_like(x_in)

t0 = time.time()    
mem = None
for i in range(x_in.shape[1]):
    with torch.no_grad():
        xout, mem = recurrent_model(x_in[:, i, :], state=mem)
    recurrent_output[:, i, :] = xout
elapsed = time.time() - t0
print("Elapsed:", elapsed*1000, "ms")
# >> Elapsed: 5431.31 ms
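
As a sanity check that make_mirror worked (a sketch reusing x_mask and recurrent_output from above), the parallel and recurrent outputs should agree up to numerical noise:

# The mirrored models share parameters, so the parallel and recurrent
# formulations compute the same function; small differences come from the
# different order of floating point operations (tolerance may need adjusting).
print(torch.allclose(x_mask, recurrent_output, atol=1e-4))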

The inference speed of the recurrent model is a lot slower than I would have expected from reading the paper. Am I using the library as intended?

Ghadjeres commented 2 years ago

I guess you need to compare the generation time of the recurrent model with the time it takes to generate autoregressively with the non-recurrent model (so 5431 ms vs. up to 2000 × 162 ms).
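
For concreteness, a minimal sketch of that autoregressive loop with the non-recurrent model, reusing model, x_in, and the imports from the snippet above (step i re-processes the whole length-i prefix, so the total cost grows quadratically):

t0 = time.time()
with torch.no_grad():
    for i in range(1, x_in.shape[1] + 1):
        # Re-run the whole length-i prefix; only the last position's output
        # would be used to produce the next token.
        prefix = x_in[:, :i, :]
        step_mask = TriangularCausalMask(i, device=x_in.device)
        out = model(prefix, attn_mask=step_mask)
elapsed = time.time() - t0
print("Elapsed:", elapsed * 1000, "ms")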

mads-oestergaard commented 2 years ago

Yeah, okay. In that case the recurrent model comes out slightly faster than the non-recurrent model:

x_in = torch.randn(2, 2000, 128)
mask = TriangularCausalMask(1, device=x_in.device)

t0 = time.time()
for i in range(x_in.shape[1]):
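    # Each call processes a single timestep with a length-1 mask, measuring
    # per-step cost; the recurrent model instead carries its history in mem.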
    with torch.no_grad():
        x_mask = model(x_in[:, i, :].unsqueeze(1), attn_mask=mask)
elapsed = time.time() - t0
print("Elapsed:", elapsed*1000, "ms")
# >> Elapsed: 6432.35 ms

So 6.432 s vs 5.431 s: the recurrent model takes about 15.6% less time (on my CPU). Is that speedup within what you would expect?