AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!

[Inference Perf] Add autotuned xla flags to improve latency for v6e #1031

Closed: wyzhang closed this 1 week ago

wyzhang commented 1 week ago

Reason: Use autotuning to find a better set of XLA flag settings that improve inference latency. This gives a ~10% latency reduction for the generate step and does not hurt prefill.
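For context, here is a minimal sketch of how a set of tuned flags is typically applied at runtime, assuming the usual `LIBTPU_INIT_ARGS` environment-variable mechanism for passing libtpu/XLA flags on TPU. The flag names below are hypothetical placeholders for illustration, not the flags added by this PR (those are in the diff).

```python
import os

# Hypothetical example: apply autotuned libtpu/XLA flags before JAX
# initializes the TPU backend. The flag names are placeholders only;
# see the PR diff for the actual autotuned flags.
os.environ["LIBTPU_INIT_ARGS"] = " ".join([
    "--xla_tpu_example_flag_a=true",  # placeholder, not a real flag
    "--xla_tpu_example_flag_b=2",     # placeholder, not a real flag
])

import jax  # import after setting the variable so the backend picks it up

print(jax.devices())  # flags take effect when the TPU backend initializes
```

Note the ordering: the variable must be set before `import jax`, because the flags are read once when the TPU backend is initialized.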

Microbenchmark results (prefill length 1024, v6e):

BEFORE:
```
Prefill benchmark results for length 1024:
  Per prefill step per device:
    Prefill step average time: 60.695 ms   <<<<<
Prefill and insert benchmark results for length 1024:
  Prefill + Insert step average time: 63.350 ms
AutoRegressive results:
  AR step average time: 63.547 ms   <<<<<
```

AFTER:
```
Prefill benchmark results for length 1024:
  Per prefill step per device:
    Prefill step average time: 60.451 ms   <<<<< 0.4% better
Prefill and insert benchmark results for length 1024:
  Prefill + Insert step average time: 62.944 ms
AutoRegressive results:
  AR step average time: 57.327 ms   <<<<< 9.8% better
```
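For reference, the reported improvements follow directly from the averages above:

$$
\frac{63.547 - 57.327}{63.547} \approx 9.8\% \ \text{(AR step)}, \qquad
\frac{60.695 - 60.451}{60.695} \approx 0.4\% \ \text{(prefill step)}
$$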