Closed: 2015aroras closed this 2 months ago.
DeviceMesh was introduced in torch 2.2, so importing it at module load time under torch 2.1 caused the training code to break. This PR defers the import until DeviceMesh is actually needed (for hybrid sharding).
Fixes #559
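A minimal sketch of the deferred-import pattern described above, assuming a hypothetical helper `get_device_mesh` and hypothetical strategy strings; this is not the PR's actual code. The `torch.distributed.device_mesh` import runs only on the hybrid-sharding path, so users on torch 2.1 with other strategies never reach it. Wrapping the `ImportError` in a clearer message is an added assumption, not something the PR states.

```python
from typing import Any, Optional

# Hypothetical strategy names that need a 2D (replicate x shard) device mesh.
HYBRID_STRATEGIES = ("HYBRID_SHARD", "_HYBRID_SHARD_ZERO2")


def get_device_mesh(strategy: str, num_nodes: int, gpus_per_node: int) -> Optional[Any]:
    """Hypothetical helper: build a DeviceMesh only when hybrid sharding is requested."""
    if strategy not in HYBRID_STRATEGIES:
        # Non-hybrid strategies never touch DeviceMesh, so torch 2.1 keeps working.
        return None
    try:
        # Deferred import: only executed on the hybrid-sharding path.
        from torch.distributed.device_mesh import init_device_mesh
    except ImportError as e:
        raise RuntimeError(
            "Hybrid sharding requires torch>=2.2 (torch.distributed.device_mesh)"
        ) from e
    # Replicate across nodes (dim 0), shard within a node (dim 1).
    return init_device_mesh(
        "cuda", (num_nodes, gpus_per_node), mesh_dim_names=("replicate", "shard")
    )
```

Keeping the import inside the function trades a slightly later failure (at hybrid-sharding setup rather than at startup) for compatibility with older torch versions on every other code path.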