google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Allow Quantize KV Cache over Multiple Dimensions #708

Status: Closed. morgandu closed this PR 2 months ago.

morgandu commented 2 months ago

Checklist

This PR allows quantizing the KV cache over a specified axis or axes.
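As a minimal sketch of what quantizing over a configurable set of axes means, the NumPy code below applies symmetric int8 quantization with one scale per slice, where the slice is defined by the reduction axes. It assumes a KV tensor laid out as [batch, sequence, kv_heads, kv_dim]; the QUANT_AXES mapping and the helper names are hypothetical illustrations, not the maxtext implementation. Under that assumption, "d" yields one scale per head vector and "hd" one scale shared by all heads at a position.

```python
import numpy as np

# HYPOTHETICAL mapping from --kv_quant_axis values to tensor axes,
# ASSUMING a KV tensor laid out as [batch, sequence, kv_heads, kv_dim].
QUANT_AXES = {
    "d": (3,),     # reduce over kv_dim: one scale per (batch, position, head)
    "hd": (2, 3),  # reduce over kv_heads and kv_dim: one scale per (batch, position)
}


def quantize_int8(x: np.ndarray, axes: tuple[int, ...]):
    """Symmetric int8 quantization with one scale per slice over `axes`."""
    # Max absolute value over the quantized axes; keepdims so it broadcasts back.
    amax = np.max(np.abs(x), axis=axes, keepdims=True)
    scale = amax / 127.0
    # Avoid dividing by zero for all-zero slices.
    safe_scale = np.where(scale == 0, 1.0, scale).astype(np.float32)
    q = np.clip(np.round(x / safe_scale), -127, 127).astype(np.int8)
    return q, safe_scale


def dequantize_int8(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Map int8 values back to float using their per-slice scales."""
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
kv = rng.standard_normal((2, 16, 8, 128)).astype(np.float32)
for flag, axes in QUANT_AXES.items():
    q, s = quantize_int8(kv, axes)
    err = np.max(np.abs(dequantize_int8(q, s) - kv))
    print(flag, "scale shape:", s.shape, "max abs error:", float(err))
```

Coarser slices ("hd") store fewer scales, but every element in a slice shares the scale set by that slice's largest magnitude, so rounding error can grow.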

Setup

Results and Analysis

Annotation

Layout

KV Cache Quantization Axis/Axes

Summary

Accuracy

Existing quantization over kv_dimension, i.e. --kv_quant_axis=d

{'rouge1': 42.4362, 'rouge2': 20.1052, 'rougeL': 27.2529, 'rougeLsum': 39.9515, 'gen_len': 1133702, 'gen_num': 995}

New quantization over kv_heads and kv_dimension, i.e. --kv_quant_axis=hd

{'rouge1': 41.9234, 'rouge2': 19.601, 'rougeL': 26.6612, 'rougeLsum': 39.5181, 'gen_len': 1156496, 'gen_num': 995}
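One way to see the trade-off between the two settings is to count how many scales each one stores. The helper below is a hypothetical worked calculation, assuming a [batch, sequence, kv_heads, kv_dim] cache, one int8 byte per element, and 4-byte scales; the example shape is made up. Under these assumptions "d" keeps one scale per head vector, while "hd" shares one scale across all heads at each position, storing kv_heads times fewer scales at the cost of coarser quantization, which is consistent with the slightly lower ROUGE scores reported for hd.

```python
def kv_cache_bytes(batch, seq_len, kv_heads, kv_dim, quant_axis, scale_bytes=4):
    """HYPOTHETICAL helper: bytes for an int8 K (or V) cache and for its scales."""
    values = batch * seq_len * kv_heads * kv_dim  # one int8 byte per element
    if quant_axis == "d":
        num_scales = batch * seq_len * kv_heads   # one scale per head vector
    elif quant_axis == "hd":
        num_scales = batch * seq_len              # one scale shared by all heads
    else:
        raise ValueError(f"unsupported quant_axis: {quant_axis!r}")
    return values, num_scales * scale_bytes


# Made-up example shape, for illustration only.
for axis in ("d", "hd"):
    values, scales = kv_cache_bytes(batch=96, seq_len=2048, kv_heads=8, kv_dim=128, quant_axis=axis)
    print(f"{axis}: values={values} B, scales={scales} B, overhead={scales / values:.2%}")
```

With these assumptions the scale overhead is scale_bytes / kv_dim for "d" and scale_bytes / (kv_heads * kv_dim) for "hd", independent of batch and sequence length.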

Performance

New quantization over kv_heads and kv_dimension, i.e. --kv_quant_axis=hd

New compute layout 0213
Cache layout 0213-0213

With the default cache layout, i.e. the same layout as compute
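To make the four-digit layout notation concrete, here is a small NumPy sketch. It assumes (not stated in the PR) that each digit names an axis of a logical KV tensor [batch=0, sequence=1, kv_heads=2, kv_dim=3] listed in physical order, and that a cache layout "XXXX-YYYY" is a pair of such orders for two caches. The helper names are hypothetical. relayout shows how a tensor held in one order is transposed into another.

```python
import numpy as np

# ASSUMPTION: digits index the logical axes below, in physical (memory) order.
LOGICAL_AXES = ("batch", "sequence", "kv_heads", "kv_dim")


def parse_order(order: str) -> tuple[int, ...]:
    """Turn a string such as '0213' into an axis permutation."""
    perm = tuple(int(c) for c in order)
    if sorted(perm) != list(range(len(LOGICAL_AXES))):
        raise ValueError(f"not a permutation of 0..3: {order!r}")
    return perm


def parse_cache_layout(layout: str) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Split a cache layout such as '0231-0213' into its two axis orders."""
    first, second = layout.split("-")
    return parse_order(first), parse_order(second)


def relayout(x: np.ndarray, src: tuple[int, ...], dst: tuple[int, ...]) -> np.ndarray:
    """Transpose x, stored in `src` order, into `dst` order."""
    # Output axis j is logical axis dst[j], found at position src.index(dst[j]) in x.
    return np.transpose(x, [src.index(axis) for axis in dst])


compute = parse_order("0213")  # [batch, kv_heads, sequence, kv_dim]
cache_a, cache_b = parse_cache_layout("0231-0213")
x = np.zeros((2, 8, 16, 128))  # batch=2, kv_heads=8, sequence=16, kv_dim=128
print(relayout(x, compute, cache_a).shape)  # (2, 8, 128, 16)
print(relayout(x, compute, cache_b).shape)  # (2, 8, 16, 128)
```

When a cache order equals the compute order, as with the default 0213-0213, relayout is the identity transpose and no data movement is implied.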

Successful requests: 995
Benchmark duration: 179.016534 s
Total input tokens: 217011
Total generated tokens: 979107
Request throughput: 5.56 requests/s
Input token throughput: 1212.24 tokens/s
Output token throughput: 5469.37 tokens/s
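The reported rates are consistent with dividing each raw total by the benchmark duration. The sketch below (a hypothetical helper, not the benchmark tool's code) reproduces the figures above from the totals.

```python
def throughput_summary(num_requests, duration_s, input_tokens, output_tokens):
    """Recompute per-second rates from a benchmark's raw totals."""
    return {
        "request_throughput": num_requests / duration_s,
        "input_token_throughput": input_tokens / duration_s,
        "output_token_throughput": output_tokens / duration_s,
    }


stats = throughput_summary(995, 179.016534, 217011, 979107)
for name, value in stats.items():
    print(f"{name}: {value:.2f}")
# request_throughput: 5.56
# input_token_throughput: 1212.24
# output_token_throughput: 5469.37
```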

xprof: https://xprof.corp.google.com/overview_page/morgandu-15535915910542734436

Cache layout 0231-0213

Cache layout tuning ranked 0231-0213 as the top cache layout by AR (autoregressive) throughput; the default cache layout 0213-0213 ranked second.

Successful requests: 995
Benchmark duration: 167.951811 s
Total input tokens: 217011
Total generated tokens: 979107
Request throughput: 5.92 requests/s
Input token throughput: 1292.10 tokens/s
Output token throughput: 5829.69 tokens/s
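The gain from the tuned cache layout can be computed directly from the two reported output token throughputs. Because both runs generated the same 979107 tokens, the throughput ratio should match the inverse ratio of the benchmark durations, which the sketch checks. The helper name is illustrative.

```python
def relative_gain(new, old):
    """Fractional improvement of `new` over `old`."""
    return (new - old) / old


# Output token throughput (tokens/s) reported for the two cache layouts.
default_tps = 5469.37  # cache layout 0213-0213
tuned_tps = 5829.69    # cache layout 0231-0213
print(f"throughput gain: {relative_gain(tuned_tps, default_tps):.1%}")

# Same number of generated tokens, so duration ratio ~= throughput ratio.
print(f"duration ratio:   {179.016534 / 167.951811:.4f}")
print(f"throughput ratio: {tuned_tps / default_tps:.4f}")
```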

xprof: https://xprof.corp.google.com/overview_page/morgandu-8809708365303398508

Performance Baselines and Past Improvement Highlights

Existing quantization over kv_dimension, i.e. --kv_quant_axis=d

Old compute layout 0123 and its history
Cache layout 1203-1203

With the old compute layout 0123, the old cache layout 1203-1203, and the existing KV cache quantization, serving throughput was 2488.511837 tokens/s.

Cache layout tuning found better-performing cache layouts, though each of the top layouts traded self-attention speed against cache-update speed.

Cache layout 0231-0231

Top layout with fast self-attention but slow cache update, yielding a serving throughput of 3333.375208 tokens/s. xprof: https://xprof.corp.google.com/overview_page/morgandu-17352860818423398257

Cache layout 0231-2130

Top layout with slow self-attention but fast cache update, yielding a serving throughput of 3037.178915 tokens/s. xprof: https://xprof.corp.google.com/overview_page/morgandu-2974435458070151506

Reshaping Q with cache layout 0231-0213

Reshaping Q during AR (autoregressive) decoding kept the compiler from broadcasting the query to the full sequence length and pushed the compute onto the MXU. This raised serving throughput to 3598.09 tokens/s, but cache update remained slow even with the optimal layout 0231-0213.
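The PR does not include the reshape itself, so the NumPy sketch below only illustrates the distinction described: for one decode step, attention scores can be computed by broadcasting the single query position to the full cached sequence length and reducing elementwise products, or as a contraction over the head dimension, the batched-matmul form that matrix units such as the MXU execute. The function names and shapes are hypothetical. Both forms give the same numbers; the sketch cannot show the TPU performance difference.

```python
import numpy as np


def ar_scores_broadcast(q, k):
    """q: [batch, 1, heads, dim], k: [batch, seq, heads, dim] -> [batch, heads, seq]."""
    q_full = np.broadcast_to(q, k.shape)  # materialize q at full sequence length
    scores = np.sum(q_full * k, axis=-1)  # elementwise multiply, then reduce: [batch, seq, heads]
    return scores.transpose(0, 2, 1)


def ar_scores_matmul(q, k):
    """Same scores expressed as a contraction over dim (a batched matmul)."""
    return np.einsum("bqhd,bshd->bhqs", q, k)[:, :, 0, :]


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 1, 4, 8))   # one new token per sequence
k = rng.standard_normal((2, 16, 4, 8))  # 16 cached positions
assert np.allclose(ar_scores_broadcast(q, k), ar_scores_matmul(q, k))
```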

New compute layout 0213
Cache layout 0213-0213

The new compute layout with the default cache layout 0213-0213 yielded a serving throughput of 2112.00 tokens/s. xprof: https://xprof.corp.google.com/overview_page/morgandu-1040975859287826924

Cache layout 0231-0213

Layout tuning identified 0231-0213 as the best cache layout, yielding 2985.70 tokens/s. xprof: https://xprof.corp.google.com/overview_page/morgandu-12972597741110718440
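To compare the baseline history at a glance, the sketch below tabulates the serving throughputs reported in this section (all with the existing --kv_quant_axis=d quantization) and computes each one relative to the original 0123 / 1203-1203 configuration. The dictionary keys and helper name are illustrative labels, not maxtext config values.

```python
BASELINE_TPS = 2488.511837  # old compute layout 0123, cache layout 1203-1203

# Serving throughput (tokens/s) as reported in this section.
REPORTED_TPS = {
    "compute 0123, cache 1203-1203": 2488.511837,
    "compute 0123, cache 0231-0231": 3333.375208,
    "compute 0123, cache 0231-2130": 3037.178915,
    "compute 0123, cache 0231-0213, reshaped Q": 3598.09,
    "compute 0213, cache 0213-0213": 2112.00,
    "compute 0213, cache 0231-0213": 2985.70,
}


def speedups(reported, baseline):
    """Throughput of each configuration relative to the baseline."""
    return {name: tps / baseline for name, tps in reported.items()}


ranked = sorted(speedups(REPORTED_TPS, BASELINE_TPS).items(), key=lambda item: item[1], reverse=True)
for name, ratio in ranked:
    print(f"{name:<42} {REPORTED_TPS[name]:>10.2f} tokens/s  {ratio:.2f}x")
```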

morgandu commented 2 months ago

Looks good, but there is a lot of code duplicated from the compute_axis PR. Can you rebase onto the latest main?

It was not ready for review; PR #667 just got merged.

Rebased to main and ready for review now.