pytorch / torchrec

PyTorch domain library for recommendation systems
BSD 3-Clause "New" or "Revised" License

force CW shards to be contiguous #2192

Open iamzainhuda opened 1 week ago

iamzainhuda commented 1 week ago

Summary: Force column-wise (CW) shards to be placed contiguously so that concatenating multiple shards is straightforward when DT.full_tensor() is called with LocalShardsWrapper. The most important case is checkpointing with state_dict, or any case where we need the global tensor of a CW-sharded table from a DTensor. This lets us avoid any extra logic for rearranging shards at checkpoint time: a simple concat on each rank is enough.
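
As an illustration of why contiguous placement allows a plain concat, here is a minimal standalone sketch. This is not torchrec code; the table shape, the two-way CW split, and the shard variables are assumptions chosen only to show that column shards reassemble with a single torch.cat when (and only when) their order matches their column offsets.

```python
import torch

# Hypothetical 4 x 8 "global" embedding table, column-wise (CW) sharded into
# two shards of 4 columns each. Shapes are illustrative only.
global_table = torch.arange(32, dtype=torch.float32).reshape(4, 8)

# Contiguous placement: shard 0 owns columns 0-3, shard 1 owns columns 4-7.
# Reassembling the global tensor is then a single concat along dim 1.
shard_0 = global_table[:, 0:4]
shard_1 = global_table[:, 4:8]
reassembled = torch.cat([shard_0, shard_1], dim=1)
assert torch.equal(reassembled, global_table)

# Non-contiguous placement (shard order not matching column offsets): a plain
# concat in shard order no longer reproduces the global tensor, so extra
# reordering logic would be needed at checkpoint time.
wrong = torch.cat([shard_1, shard_0], dim=1)
assert not torch.equal(wrong, global_table)
```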

Also add callbacks to MemoryBalancedEmbeddingShardingPlanner and HeteroEmbeddingShardingPlanner.

Differential Revision: D59134513

facebook-github-bot commented 1 week ago

This pull request was exported from Phabricator. Differential Revision: D59134513