Describe the bug
Hi torchrec team, I found that if the keys of an input bag are spread across multiple devices, the MEAN pooling result is incorrect. The root cause is that fbgemm divides the embedding results by the local bag size, while torchrec's output_dist performs a SUM reduce scatter (RW sharding), so the per-rank local means get summed instead of forming a global mean.
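For intuition, here is a back-of-the-envelope check (plain Python arithmetic, not torchrec code; the numbers mirror the repro below, where every embedding row is initialized to 2.0 and the bag is [0, 1, 2]):

# Hypothetical sketch of the bug: each rank divides by its *local* bag size,
# and the SUM reduce scatter then adds the two local means instead of computing a global mean.
rank0_local_mean = (2.0 + 2.0) / 2   # rows 0 and 1 live on rank 0, local bag size 2
rank1_local_mean = 2.0 / 1           # row 2 lives on rank 1, local bag size 1
observed = rank0_local_mean + rank1_local_mean   # SUM reduce scatter -> 4.0 (what torchrec returns)
expected = (2.0 + 2.0 + 2.0) / 3                 # true MEAN over the whole bag -> 2.0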
Steps to reproduce
Script:
import os
import sys
import torch
import torchrec
import torch.distributed as dist
from torchrec.distributed.fbgemm_qcomm_codec import get_qcomm_codecs_registry, QCommsConfig, CommType
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology, ParameterConstraints
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingType,
)
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")

def init_fn(x: torch.Tensor):
    with torch.no_grad():
        x.fill_(2.0)

ebc = torchrec.EmbeddingBagCollection(
    device=torch.device("meta"),
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=4,
            num_embeddings=4,
            feature_names=["product"],
            init_fn=init_fn,
            pooling=torchrec.PoolingType.MEAN,
        ),
    ],
)

sharding_types = [ShardingType.ROW_WISE.value]
constraints = {"product_table": ParameterConstraints(sharding_types=sharding_types)}

from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

planner = EmbeddingShardingPlanner(
    constraints=constraints,
)
sharders = [EmbeddingBagCollectionSharder()]
plan = planner.collective_plan(ebc, sharders, pg=dist.GroupMember.WORLD)

apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=ebc.parameters(),
    optimizer_kwargs={"lr": 0.02},
)

model = torchrec.distributed.DistributedModelParallel(
    ebc, sharders=sharders, device=torch.device("cuda"), plan=plan
)
mb = torchrec.KeyedJaggedTensor(
    keys=["product"],
    values=torch.tensor([0, 1, 2]).cuda(),  # keys [0, 1] land on rank 0, [2] on rank 1
    lengths=torch.tensor([3], dtype=torch.int64).cuda(),
)
ret = model(mb)  # => this is an awaitable
product = ret.to_dict()["product"]  # implicitly calls awaitable.wait(); EmbeddingCollection does not have a to_dict attribute

if local_rank == 0:
    print(model.plan)
    print(f'product {product} ')  # result is 4!! (2+2) / 2 + (2) / 1
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun \
    --nnodes 1 \
    --nproc_per_node 2 \
    ./mean_polling.py
module:

 param        | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | ------
product_table | row_wise      | fused          | [0, 1]

 param        | shard offsets | shard sizes | placement
------------- | ------------- | ----------- | -------------
product_table | [0, 0]        | [2, 4]      | rank:0/cuda:0
product_table | [2, 0]        | [2, 4]      | rank:1/cuda:1
product tensor([[4., 4., 4., 4.]], device='cuda:0', grad_fn=)
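For comparison, a minimal single-process baseline sketch (my own check, assuming the same table config; no sharding involved) returns the expected MEAN of 2.0 for every element:

import torch
import torchrec

def init_fn(x: torch.Tensor):
    with torch.no_grad():
        x.fill_(2.0)

# Same table as the repro above, but unsharded on a single device.
ebc = torchrec.EmbeddingBagCollection(
    device=torch.device("cpu"),
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=4,
            num_embeddings=4,
            feature_names=["product"],
            init_fn=init_fn,
            pooling=torchrec.PoolingType.MEAN,
        ),
    ],
)
kjt = torchrec.KeyedJaggedTensor(
    keys=["product"],
    values=torch.tensor([0, 1, 2]),
    lengths=torch.tensor([3], dtype=torch.int64),
)
print(ebc(kjt).to_dict()["product"])  # expected: tensor([[2., 2., 2., 2.]], ...)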