huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Performance Regression from commit 7dcd870 #22683

Closed · fpgaminer closed this issue 1 year ago

fpgaminer commented 1 year ago

System Info

Who can help?

@ArthurZucker @younesbelkada

Information

Tasks

Reproduction

I have a benchmark script which measures the generation speed of different LLaMA models. Before commit 7dcd870, my generation speed averaged around 48 tokens/s in the ideal case on an RTX 3090. After that commit, the average speed is 43 tokens/s.

The specific issue seems to be the change to apply_rotary_pos_emb. My guess is that the slowdown comes from replacing a rather simple slicing of two tensors with a scatter-gather.

To test my theory, I patched apply_rotary_pos_emb back to its pre-7dcd870 state and minimally modified LlamaAttention accordingly, with no other changes. Speed jumped back to 48 tokens/s.
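
For context, the pre-7dcd870 version was roughly the following; this is a sketch reconstructed from the old modeling_llama.py (the name old_apply_rotary_pos_emb is mine), not a verbatim copy. It just slices the cached cos/sin at an offset, with no per-row gather:

def old_apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    # cos/sin: [1, 1, max_seq_len, dim]; q/k: [bs, num_heads, seq_len, dim]
    # Slice the precomputed caches along the sequence dimension only.
    cos = cos[..., offset : q.shape[-2] + offset, :]
    sin = sin[..., offset : q.shape[-2] + offset, :]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed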

The problem should apply generally, but the specific script I'm using is: https://github.com/fpgaminer/GPTQ-triton/blob/99ec4a3adb7fad9de33ff026bbfb64cbb3bab2f8/benchmark_generate.py

Expected behavior

I would not expect a 10% drop in performance.

sgugger commented 1 year ago

cc @gante and @ArthurZucker

gante commented 1 year ago

@fpgaminer commit 7dcd870 fixes generation when there is padding in the input (which is almost always the case for batch_size>1). It's natural that it introduces slowdowns, as the correct behavior implies changing to the tensor gathering you mentioned :)
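
To make that concrete, here is a made-up two-row batch where the first row is left-padded. The position_ids construction below is roughly what transformers does for batched generation (see prepare_inputs_for_generation in modeling_llama.py), and it shows why each row needs its own slice of the cos/sin cache:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # row 0: two padding tokens on the left
                               [1, 1, 1, 1, 1]])  # row 1: no padding

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
# The rows index different positions, so cos/sin must be gathered per row
# instead of sliced once for the whole batch.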

We optimize for correctness first and performance second. To skip this gathering while remaining correct, .generate() would need to be rewritten to dynamically squeeze out padding and evict completed rows, which is on our roadmap for the coming months.

Meanwhile, is there anything else we can help you with?

fpgaminer commented 1 year ago

That's fair, though a 10% performance hit is rather painful.

To that end, here's my attempt to optimize apply_rotary_pos_emb:

import torch
import triton.testing
from transformers.models.llama.modeling_llama import rotate_half

def ref_apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def fast_apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]; squeeze one dim at a time for older PyTorch
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def test_foo(B, L):
    cos = torch.randn(1, 1, 2048, 128, dtype=torch.float16, device='cuda')
    sin = torch.randn(1, 1, 2048, 128, dtype=torch.float16, device='cuda')
    position_ids = torch.randint(0, 2048, (B, L), dtype=torch.int64, device='cuda')

    q = torch.randn(B, 32, L, 128, dtype=torch.float16, device='cuda')
    k = torch.randn(B, 32, L, 128, dtype=torch.float16, device='cuda')

    # Verify
    ref = ref_apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    fast = fast_apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    assert torch.equal(ref[0], fast[0])
    assert torch.equal(ref[1], fast[1])

    # Benchmark
    ref_ms, ref_min_ms, ref_max_ms = triton.testing.do_bench(lambda: ref_apply_rotary_pos_emb(q, k, cos, sin, position_ids))
    fast_ms, fast_min_ms, fast_max_ms = triton.testing.do_bench(lambda: fast_apply_rotary_pos_emb(q, k, cos, sin, position_ids))

    speedup = ref_ms * 100 / fast_ms
    print(f'{B} | {L:3d} |  {ref_ms:.6f} | {fast_ms:.6f} | {speedup:.2f}%')

print('B |  L  |    ref    |   fast   | speedup')
for B in [1, 2, 4, 8]:
    for L in [1, 2, 4, 8, 10, 100]:
        test_foo(B, L)

Output:

B |  L  |    ref    |   fast   | speedup
1 |   1 |  0.043008 | 0.035840 | 120.00%
1 |   2 |  0.044032 | 0.036864 | 119.44%
1 |   4 |  0.047104 | 0.038912 | 121.05%
1 |   8 |  0.046080 | 0.039936 | 115.38%
1 |  10 |  0.048128 | 0.039936 | 120.51%
1 | 100 |  0.058368 | 0.052224 | 111.76%
2 |   1 |  0.047104 | 0.036864 | 127.78%
2 |   2 |  0.049152 | 0.039936 | 123.08%
2 |   4 |  0.050176 | 0.040960 | 122.50%
2 |   8 |  0.050176 | 0.041984 | 119.51%
2 |  10 |  0.050176 | 0.041984 | 119.51%
2 | 100 |  0.079872 | 0.070656 | 113.04%
4 |   1 |  0.051200 | 0.039936 | 128.21%
4 |   2 |  0.053248 | 0.040960 | 130.00%
4 |   4 |  0.054272 | 0.041984 | 129.27%
4 |   8 |  0.057344 | 0.045056 | 127.27%
4 |  10 |  0.057344 | 0.045056 | 127.27%
4 | 100 |  0.130048 | 0.119808 | 108.55%
8 |   1 |  0.057344 | 0.040960 | 140.00%
8 |   2 |  0.059392 | 0.041984 | 141.46%
8 |   4 |  0.062464 | 0.045056 | 138.64%

For reference, the pre-7dcd870 function runs in 0.030 ms at B=1, L=1, so this isn't quite as fast but gets closer.

Would a pull request with this change be welcome? I've done my best to verify its correctness with the above code.

gante commented 1 year ago

@fpgaminer that is great! Absolutely, a PR would be very welcome 🙌

(We'd be happy to integrate other optimization opportunities if you spot them; we rarely have the bandwidth to optimize our modeling code.)

aljungberg commented 1 year ago

> @fpgaminer commit 7dcd870 fixes generation when there is padding in the input (which is almost always the case for batch_size>1). It's natural that it introduces slowdowns, as the correct behavior implies changing to the tensor gathering you mentioned :)

Maybe there's something I'm not seeing here, but LLaMA uses rotary positional embeddings, so left padding should have no effect on the result?

Sure, the intermediate result from apply_rotary_pos_emb changes if you shift all tokens left or right, but the whole point of using relative embeddings is that they're invariant to absolute position in terms of the final attention weight. So you can shift all tokens 50 positions to the right and the attention score between pairs of tokens will be the same, modulo any rounding errors.
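
As a quick sanity check of that claim, here is a toy RoPE re-implementation (my own sketch, not the transformers code) showing that shifting every position by 50 leaves the attention scores unchanged up to floating point noise:

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope(x, positions, base=10000.0):
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    freqs = positions.to(torch.float32)[:, None] * inv_freq[None, :]
    emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, dim]
    return x * emb.cos() + rotate_half(x) * emb.sin()

torch.manual_seed(0)
seq_len, dim = 16, 128
q = torch.randn(seq_len, dim)
k = torch.randn(seq_len, dim)

pos = torch.arange(seq_len)
scores = rope(q, pos) @ rope(k, pos).T                    # positions 0..15
scores_shifted = rope(q, pos + 50) @ rope(k, pos + 50).T  # positions 50..65

print((scores - scores_shifted).abs().max())  # tiny: pure rounding error, not exactly zero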

Or are you saying there are cases when padding is literally inserted inside of the sequence, therefore changing the relative distances between tokens, @gante?

gante commented 1 year ago

@aljungberg I agree with everything you wrote; rotary position embeddings should depend only on relative positions. In practice, the small rounding errors compound over autoregressive text generation, leading greedy decoding (which is normally robust to small fluctuations) to produce different text.
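
To make the failure mode concrete with invented numbers: when two logits are nearly tied, a perturbation of that size is enough to flip the greedy argmax, and the generations diverge from that token onward.

import torch

# Toy example, all values made up: two near-tied logits plus a small
# perturbation of the kind accumulated numerical error can introduce.
logits = torch.tensor([3.1415, 3.1417, -2.0])
noise = torch.tensor([2e-4, -2e-4, 0.0])

print(torch.argmax(logits).item())          # 1
print(torch.argmax(logits + noise).item())  # 0 -> a different next token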

With the right position index, the error becomes much smaller, and the results become more stable regardless of padding. That's why we also added it to our high-performance text generation repo, despite the difference being quite small.

Out of curiosity, this test was failing on GPTNeoX and Llama before we added this change. In theory, it shouldn't have failed at all!

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.