pytorch/xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Bug - Using Sharding in Flash Attention with segment ids. #8334

Open dudulightricks opened 1 month ago

dudulightricks commented 1 month ago

🐛 Bug

When training a sharded model with Flash Attention that uses segment_ids, the segment_ids are not sharded along with q/k/v, resulting in a size mismatch inside the kernel wrapper. We attempted to resolve this by modifying custom_kernel.py (PR #8333), which does fix the mismatch. However, with that fix in place, the loss no longer converges to zero when training on dummy data; it stalls at 0.2.
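Roughly, the idea is to shard the [batch, seq] segment_ids on the same batch axis that the 4-D partition_spec already uses for q/k/v (the wrapper applies enable_manual_sharding to q/k/v before invoking the Pallas kernel). A hypothetical helper sketching that direction, not the PR's verbatim code:

```python
import torch
import torch_xla.distributed.spmd as xs

def shard_segment_ids(q_segment_ids: torch.Tensor,
                      kv_segment_ids: torch.Tensor,
                      partition_spec, mesh: xs.Mesh):
    # segment_ids are [batch, seq], so reuse only the batch component of
    # the 4-D q/k/v partition_spec; the sequence dim stays unsharded.
    # (Deriving the spec this way is an assumption for illustration.)
    seg_spec = (partition_spec[0], None)
    q_segment_ids = xs.enable_manual_sharding(
        q_segment_ids, seg_spec, mesh=mesh).global_tensor
    kv_segment_ids = xs.enable_manual_sharding(
        kv_segment_ids, seg_spec, mesh=mesh).global_tensor
    return q_segment_ids, kv_segment_ids
```

This makes the per-device shapes consistent, but as noted above the loss then stalls at 0.2, so the masking semantics under sharding may still be wrong.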

To Reproduce

Run any training job that uses Flash Attention with segment_ids on a sharded (SPMD) model; a minimal sketch follows.
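A minimal repro sketch. Shapes, dtypes, and the mesh layout here are assumptions; the keyword names follow torch_xla's flash_attention wrapper in torch_xla.experimental.custom_kernel:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
device = xm.xla_device()

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

B, H, S, D = 8, 4, 1024, 64  # dummy sizes
q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, S, D, device=device)
v = torch.randn(B, H, S, D, device=device)

# One segment per sequence; shape [batch, seq]. The kernel masks attention
# across segment boundaries.
segment_ids = torch.zeros(B, S, dtype=torch.float32, device=device)

# q/k/v are manually sharded on the 'data' (batch) axis via partition_spec,
# but segment_ids are passed through unsharded -- hence the size mismatch.
out = flash_attention(
    q, k, v,
    q_segment_ids=segment_ids,
    kv_segment_ids=segment_ids,
    partition_spec=('data', None, None, None),
    mesh=mesh,
)
```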

Expected behavior

With this fix, the loss is expected to converge when training a sharded model with Flash Attention and segment_ids.

Environment