AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Update sharding annotation for dropping #1035

Closed: RissyRan closed this 1 week ago

RissyRan commented 1 week ago

Description

Add the missing tensor-parallel (TP) sharding annotation for the dropping weights after initialization.
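The description does not include the diff. Below is a minimal sketch of what re-applying a TP sharding annotation to a weight after initialization can look like in JAX. It is not the MaxText change itself. The mesh axis name `tensor`, the expert/embed/mlp shapes, and `init_weights` are hypothetical. The sketch uses the real `jax.lax.with_sharding_constraint` API with a `NamedSharding`:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over all available devices, with a single "tensor" (TP) axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("tensor",))

# Hypothetical weight shape: (num_experts, embed, mlp).
NUM_EXPERTS, EMBED, MLP = 4, 8, 16


def init_weights(key):
    """Hypothetical initializer for per-expert up-projection weights."""
    return jax.random.normal(key, (NUM_EXPERTS, EMBED, MLP), dtype=jnp.float32) * 0.02


@jax.jit
def init_and_annotate(key):
    w = init_weights(key)
    # Re-apply the TP sharding after initialization: split the mlp dimension
    # across the "tensor" mesh axis and replicate the expert and embed dimensions.
    tp_sharding = NamedSharding(mesh, P(None, None, "tensor"))
    return jax.lax.with_sharding_constraint(w, tp_sharding)


w = init_and_annotate(jax.random.PRNGKey(0))
print(w.shape)
```

`with_sharding_constraint` changes only how the compiler lays the array out across devices, not its values. Here `MLP` must be divisible by the number of devices on the `tensor` axis.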