🐛 Bug
When training a sharded model with Flash Attention using segment_ids, the segment_ids are not sharded, resulting in a size mismatch. We attempted to resolve this by modifying custom_kernel.py (PR #8333), which successfully addresses the mismatch. However, with this fix, the loss does not converge to zero when training with dummy data; instead, it stalls at 0.2.
To Reproduce
Run any training job that uses Flash Attention with segment_ids under sharding.
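A minimal sketch of such a run, assuming the `flash_attention` wrapper in `torch_xla/experimental/custom_kernel.py` accepts `q_segment_ids`/`kv_segment_ids` plus `partition_spec`/`mesh` for SPMD; tensor shapes, mesh layout, and the segment layout are illustrative, not the exact script used:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("data", "model"))

# Illustrative shapes: batch, heads, sequence length, head dim.
B, H, S, D = 8, 4, 1024, 64
device = xm.xla_device()
q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, S, D, device=device)
v = torch.randn(B, H, S, D, device=device)

# segment_ids mark packed-sequence boundaries; here two segments per example.
segment_ids = torch.cat(
    [torch.zeros(B, S // 2, dtype=torch.int32),
     torch.ones(B, S // 2, dtype=torch.int32)], dim=-1).to(device)

# Shard activations along the batch ("data") axis. Before the custom_kernel.py
# change, segment_ids stayed unsharded, which triggered the size mismatch.
spec = ("data", None, None, None)
xs.mark_sharding(q, mesh, spec)
xs.mark_sharding(k, mesh, spec)
xs.mark_sharding(v, mesh, spec)

out = flash_attention(
    q, k, v, causal=True,
    q_segment_ids=segment_ids, kv_segment_ids=segment_ids,
    partition_spec=spec, mesh=mesh)
print(out.shape)
```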
Expected behavior
The loss is expected to converge when using this fix with sharded training (with Flash Attention and segment_ids).
Environment