AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Switch Expert axis to avoid unnecessary copy for layout change #880

Closed. ZhiyuLi-goog closed this 2 months ago.

ZhiyuLi-goog commented 2 months ago

Before: https://screenshot.googleplex.com/5LhbKL58gBwAGUr
After: https://screenshot.googleplex.com/BFziQHuGGDuAjhN

This gives a 0.5% to 0.8% improvement after borrowing the sharding from paxml. With int8 there is a +6% improvement, thanks to a better layout in the backward pass; I'm still analyzing that result (WIP).
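The thread does not include the diff, so the following is only an illustration, assuming that "Expert axis" refers to where the experts dimension sits in the MoE tensors. It shows a minimal JAX sketch in which the dispatched activations and the expert weights both keep the expert axis leading and are sharded along a mesh axis named `"expert"`. The sizes `E, C, D, F`, the function `expert_ffn`, and the mesh axis name are hypothetical, not MaxText's actual code. Whether XLA inserts a copy for the alternative layout depends on its layout assignment, which this toy does not measure; it only checks that both layouts compute the same result.

```python
import math

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical sizes: experts, capacity (tokens per expert), embed dim, MLP dim.
E, C, D, F = 4, 8, 16, 32

# 1-D device mesh with a single "expert" axis. Use a device count that divides E
# so the expert axis can be split evenly (on one CPU this is a trivial 1-way mesh).
n_dev = math.gcd(E, len(jax.devices()))
mesh = Mesh(np.array(jax.devices()[:n_dev]), axis_names=("expert",))
expert_sharding = NamedSharding(mesh, P("expert", None, None))


def expert_ffn(dispatched, w_in):
    # dispatched: [E, C, D], w_in: [E, D, F]. The expert axis leads in both
    # operands and in the result, so the einsum batches over E in place.
    return jnp.einsum("ecd,edf->ecf", dispatched, w_in)


def expert_ffn_alt(dispatched, w_in_alt):
    # Alternative weight layout [D, E, F]: the expert axis sits in the middle,
    # so it no longer lines up with the activations' leading expert axis.
    return jnp.einsum("ecd,def->ecf", dispatched, w_in_alt)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.device_put(jax.random.normal(key_x, (E, C, D)), expert_sharding)
w = jax.device_put(jax.random.normal(key_w, (E, D, F)), expert_sharding)

y = jax.jit(expert_ffn)(x, w)
y_alt = jax.jit(expert_ffn_alt)(x, jnp.transpose(w, (1, 0, 2)))
```

Both functions return an `[E, C, F]` result; the difference is only in how the expert axis lines up across operands, which is what determines whether the compiler has to reorder data before the batched matmul.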

gobbleturk commented 2 months ago

LGTM! Please wait for @RissyRan's approval as well.

ZhiyuLi-goog commented 2 months ago

This still needs a code owner's approval. Thank you, @gobbleturk.