pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

[BE] replace the extra DeviceMesh _flatten with mesh access #666

Closed XilunWu closed 3 weeks ago

XilunWu commented 3 weeks ago

Stack from ghstack (oldest at bottom):

Summary: https://github.com/pytorch/pytorch/pull/138945 fixes DeviceMesh access on flattened meshes that are constructed from more than 2 meshes. Refer to the fix PR for details if interested.

In #592 we avoided this issue by calling _flatten instead of directly accessing the flattened mesh. Now that the fix has been merged in PyTorch, we want to switch back to mesh access, which is more straightforward.

XilunWu commented 3 weeks ago

It's better to have a try-except to tell users that they are not using the latest PyTorch.

Oh yeah that's right...
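
A minimal sketch of the try-except idea, not the code in this PR. The helper name `get_flattened_mesh`, the dimension name `"dp_cp"`, and the caught exception types are assumptions for illustration. The exact error an older PyTorch raises is not stated here. The helper only relies on `mesh[name]` indexing, so a plain dict stands in for a DeviceMesh in the usage lines.

```python
def get_flattened_mesh(mesh, dim_name):
    """Access a flattened mesh by name, with a hint if the access fails.

    `mesh` is anything that supports `mesh[dim_name]`, e.g. a DeviceMesh.
    """
    try:
        # Direct mesh access, instead of calling _flatten again.
        return mesh[dim_name]
    except (KeyError, RuntimeError) as exc:
        # Assumption: a failed lookup surfaces as KeyError or RuntimeError.
        raise RuntimeError(
            f"Could not access mesh dimension {dim_name!r}. "
            "The installed PyTorch may predate "
            "https://github.com/pytorch/pytorch/pull/138945; "
            "try updating to a recent PyTorch build."
        ) from exc


# Stand-in for a real DeviceMesh (hypothetical dimension name).
fake_mesh = {"dp_cp": "flattened-dp-cp-mesh"}
submesh = get_flattened_mesh(fake_mesh, "dp_cp")
```

Chaining with `from exc` keeps the original traceback, so the underlying failure stays visible next to the upgrade hint.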