google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0
1.39k stars 247 forks

Allow Different Compute Layout for Attention #709

Closed morgandu closed 1 month ago

morgandu commented 1 month ago

This PR introduces compute layout control, allowing a different compute layout for attention.

Setup

Results and Analysis

The goal of introducing the new compute layout is to potentially avoid cache layout tuning, though we can still tune the cache layout to seek out and verify the best performance.

Annotation

Layout

Summary

The existing attention compute layout is 0123; we introduced a different compute layout, 0213, which is TPU-friendly.
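A layout string such as 0123 or 0213 can be read as a permutation of the attention tensor's axes. As a sketch (the exact axis meaning here is an assumption, not taken from the PR), suppose the default 0123 order is (batch, seq, heads, head_dim); 0213 then swaps seq and heads. Illustrated with NumPy for brevity; MaxText itself uses `jax.numpy`, whose `transpose` has the same semantics:

```python
import numpy as np

# Assumed axis meaning for the default 0123 layout:
#   0 = batch, 1 = seq, 2 = heads, 3 = head_dim
x_0123 = np.zeros((2, 8, 4, 16))  # (batch, seq, heads, head_dim)

def to_layout(x, layout="0213"):
    """Permute axes according to a layout string such as '0213'."""
    return np.transpose(x, [int(c) for c in layout])

x_0213 = to_layout(x_0123, "0213")  # (batch, heads, seq, head_dim)
print(x_0213.shape)  # (2, 4, 8, 16)
```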

We introduced the 0213 compute layout to verify:

Performance

Existing compute layout 0123 and its history

Cache layout 1203-1203

With the existing cache layout 1203-1203, throughput was 2591.642232 tokens/s, about a 3x improvement over the default cache layout 0123-0123.

Cache layout 2013-2013

After layout tuning, the optimal prefill-ar cache layout was 2013-2013, with throughput 3347.180221 tokens/s, a 29% improvement.
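The 29% figure follows directly from the two reported throughputs:

```python
baseline = 2591.642232  # tokens/s with cache layout 1203-1203
tuned = 3347.180221     # tokens/s with tuned cache layout 2013-2013

# Relative improvement of the tuned layout over the baseline, ~29%.
improvement = tuned / baseline - 1.0
print(f"{improvement:.1%}")
```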

New compute layout 0213

Cache layout 0213-0213

With both caches in the same layout as the compute layout, i.e. 0213-0213 (xprof: https://xprof.corp.google.com/overview_page/morgandu-12159058496322304249), we got 3273.96 tokens/s; layout tuning verified this is among the top-performing layouts.

Cache layout 0213-0132

The tuned cache layout that gives us the best throughput, 3329.45 tokens/s, is 0213-0132 (xprof: https://xprof.corp.google.com/overview_page/morgandu-5743582688063478644).
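Cache layout tuning as described above amounts to a search over candidate axis permutations for the prefill and autoregressive (ar) caches. A minimal sketch, assuming a hypothetical `measure_throughput` benchmark callback (not a real MaxText API):

```python
from itertools import permutations

def tune_cache_layout(measure_throughput):
    """Exhaustively search prefill/ar cache layout pairs, keeping the best.

    `measure_throughput` takes a (prefill_layout, ar_layout) pair of
    layout strings and returns tokens/s; it stands in for a real benchmark run.
    """
    best_layout, best_tps = None, 0.0
    for prefill in permutations("0123"):
        for ar in permutations("0123"):
            layout = ("".join(prefill), "".join(ar))
            tps = measure_throughput(layout)
            if tps > best_tps:
                best_layout, best_tps = layout, tps
    return best_layout, best_tps
```

In practice the search space is only 24 x 24 = 576 pairs, but each measurement is a full benchmark run, which is why avoiding the tuning step entirely (by picking a TPU-friendly compute layout up front) is attractive.

```python
# Usage with a fake benchmark that favors one known pair:
fake = lambda layout: 1.0 if layout == ("0213", "0132") else 0.5
print(tune_cache_layout(fake))  # (('0213', '0132'), 1.0)
```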

Accuracy

No regression in ROUGE scores between 0123 and 0213:

```
{'rouge1': 42.1738, 'rouge2': 19.6973, 'rougeL': 26.9088, 'rougeLsum': 39.6794, 'gen_len': 1144204, 'gen_num': 995}
```
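As a sanity check on the reported aggregates, the mean generation length follows directly from `gen_len` and `gen_num`:

```python
scores = {'rouge1': 42.1738, 'rouge2': 19.6973, 'rougeL': 26.9088,
          'rougeLsum': 39.6794, 'gen_len': 1144204, 'gen_num': 995}

# Average tokens generated per example, ~1150.
avg_len = scores['gen_len'] / scores['gen_num']
print(round(avg_len, 2))
```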