samos123 opened 3 hours ago
Hi @samos123, you can use the input dispatcher:
https://github.com/apple/axlearn/blob/ac63eef8a76ee8e7fcb7e539ca1331e885ce286c/axlearn/common/input_tf_data.py#L1165-L1167
https://github.com/apple/axlearn/blob/ac63eef8a76ee8e7fcb7e539ca1331e885ce286c/axlearn/common/input_dispatch.py#L17-L33
Some hosts will produce padding feeds, which are dropped during input dispatch. I have some ideas to make this a bit simpler soon, but this should unblock you for now.
`fsdp=16 model=16 global_batch_size=16` should work on 256 chips.
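To make the mechanics concrete, here is a host-local NumPy sketch of the concept (illustrative names only, not AXLearn's actual API): every physical feed emits a batch of the same shape so all hosts stay in lockstep, but only the logical feeds carry real examples, and dispatch keeps just those rows:

```python
# Conceptual sketch only -- names and shapes are illustrative, not AXLearn's API.
import numpy as np

num_physical_feeds = 64          # e.g. one feed per process on a v6e-256 slice
logical_feed_indices = range(4)  # only these feeds carry real examples
batch_per_feed = 4               # global logical batch = 4 feeds * 4 = 16

def feed_batch(feed_index: int) -> np.ndarray:
    """Each feed emits a uniformly shaped batch; non-logical feeds emit padding."""
    if feed_index in logical_feed_indices:
        return np.full((batch_per_feed, 8), feed_index, dtype=np.float32)
    return np.zeros((batch_per_feed, 8), dtype=np.float32)  # padding feed

# The global *physical* batch is the concatenation over all feeds...
global_physical = np.concatenate(
    [feed_batch(i) for i in range(num_physical_feeds)])
# ...and dispatch drops the padding rows, keeping only the logical feeds' examples.
keep = [i * batch_per_feed + j
        for i in logical_feed_indices
        for j in range(batch_per_feed)]
global_logical = global_physical[keep]
assert global_logical.shape == (16, 8)  # the global logical batch
```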
The use case is being able to use a global batch size that is smaller than the total number of JAX processes.
This is supported in MaxText using this trick: https://github.com/AI-Hypercomputer/maxtext/blob/4cf51b7f204e109df502cf2d54b4d5005f597b09/MaxText/train.py#L289-L291
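(Paraphrasing the linked lines from memory, so treat the sketch below as the general pattern rather than MaxText's exact code: every host still loads at least one example, so the padded global batch is a multiple of the process count, and the batch is then sliced down to the size you actually train on. All identifiers below are illustrative.)

```python
# Illustrative sketch of the "load padded, slice down" trick; identifiers are
# made up for this example and are not MaxText's actual names.
import jax
import jax.numpy as jnp

NUM_PROCESSES = 64                 # e.g. 64 hosts on a v6e-256 slice
PER_HOST_LOAD = 1                  # each host must load at least one example
GLOBAL_BATCH_TO_LOAD = NUM_PROCESSES * PER_HOST_LOAD  # padded batch: 64
GLOBAL_BATCH_TO_TRAIN_ON = 16      # the batch the model should actually see

def trim_batch(example_batch):
    """Drop padding examples loaded only to satisfy per-host batching."""
    return jax.tree_util.tree_map(
        lambda x: x[:GLOBAL_BATCH_TO_TRAIN_ON], example_batch)

batch = {"inputs": jnp.zeros((GLOBAL_BATCH_TO_LOAD, 128), jnp.int32)}
batch = trim_batch(batch)
assert batch["inputs"].shape[0] == GLOBAL_BATCH_TO_TRAIN_ON
```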
Trying to get the 405B model running on v6e-256 (`fsdp=16 model=16`) but getting hit with this error: