Summary:
Makes it easier to concatenate multiple shards when we call DT.full_tensor() on a DTensor backed by LocalShardsWrapper. The most important case is checkpointing with state_dict, or any case where we need the global tensor of a CW-sharded table from a DTensor. This avoids any extra shard-rearranging logic at checkpoint time: each rank can do a simple concat. A rough illustration is sketched below.
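As a minimal sketch (not the actual LocalShardsWrapper or full_tensor() code, and using made-up shard shapes), this shows why keeping a rank's CW shards ordered by column offset reduces reconstruction to a single concat on the sharding dim:

```
import torch

# Suppose one rank holds two CW shards of a 4 x 8 table, covering
# columns [0, 4) and [4, 8). If the wrapper keeps them ordered by
# their column offsets, rebuilding the rank-local block is a single
# concat along dim 1 -- no rearranging logic required.
shard_a = torch.randn(4, 4)  # columns 0-3
shard_b = torch.randn(4, 4)  # columns 4-7
local_shards = [shard_a, shard_b]  # ordered by column offset

rank_local = torch.cat(local_shards, dim=1)
assert rank_local.shape == (4, 8)
```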
Also adds callbacks to MemoryBalancedEmbeddingShardingPlanner and HeteroEmbeddingShardingPlanner.
Differential Revision: D59134513