google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0
1.39k stars, 247 forks

Support partial overrides for logical_axis_rules. #730

Closed: golechwierowicz closed 2 weeks ago

golechwierowicz commented 3 weeks ago

Sometimes we want different shardings for particular models. The rules change quite frequently in the base.yml file, and every change then has to be repeated in the model-specific config files. This diff adds support for partial overrides, i.e. overriding only the logical axes of interest rather than the entire rule set. The usage is best demonstrated in the yml files included in the diff.
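To illustrate the idea of a partial override, here is a minimal sketch of how such a merge could work. It is not the code from this diff: the function name `merge_logical_axis_rules`, the example axis names, and the assumption that each rule is a `(logical_axis, mesh_axes)` pair with a unique logical axis name are all hypothetical and chosen only for illustration.

```python
def merge_logical_axis_rules(base_rules, overrides):
    """Return base_rules with matching entries replaced by overrides.

    Hypothetical helper, not part of MaxText. Assumes every rule is a
    (logical_axis, mesh_axes) pair and that each logical axis name
    appears at most once in each list.
    """
    override_map = dict(overrides)
    base_names = {name for name, _ in base_rules}

    # Keep the base ordering; swap in the override where one exists.
    merged = [(name, override_map.get(name, axes)) for name, axes in base_rules]

    # Overrides naming an axis absent from the base are appended at the end.
    merged += [(name, axes) for name, axes in overrides if name not in base_names]
    return merged


# Illustrative rule sets (the axis names are made up for this example).
base_rules = [
    ("activation_batch", ("data", "fsdp")),
    ("mlp", ("tensor",)),
    ("embed", ("fsdp",)),
]
model_overrides = [("mlp", ("fsdp", "tensor"))]

merged_rules = merge_logical_axis_rules(base_rules, model_overrides)
for name, axes in merged_rules:
    print(name, axes)
```

With this approach a model-specific config only lists the axes it changes, so edits to the remaining rules in the base file are picked up without touching the model file.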

golechwierowicz commented 3 weeks ago

Changes applied, PTL. I'm not sure where the copybara error comes from, though.

gobbleturk commented 3 weeks ago

It looks like you are missing the copyright header on the new files. https://screenshot.googleplex.com/5SsDGDimrKsSUHf

E.g. you need this in the new files you are adding