pytorch / torchrec

Pytorch domain library for recommendation systems
BSD 3-Clause "New" or "Revised" License

[BUG] EBC Mean pooling division is not handled properly. #2362

Closed JacoCheung closed 1 month ago

JacoCheung commented 1 month ago

Describe the bug

Hi torchrec team, I found that if the keys of an input bag are spread across multiple devices, the mean pooling result is incorrect. The root cause is that fbgemm divides the embedding results by the local bag size on each rank, while the output_dist of torchrec is a SUM reduce-scatter (RW sharding). As a result, the per-rank means are added together instead of the global sum being divided by the global bag size.
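The arithmetic can be reproduced without any GPUs. The following is a minimal pure-Python sketch, not torchrec or fbgemm code: the helper names `local_mean_then_sum` and `sum_then_global_mean` are hypothetical. It assumes every embedding row holds the value 2.0, and that ids 0 and 1 live on rank 0 while id 2 lives on rank 1.

```python
# Pure-Python illustration; helper names are hypothetical, not torchrec/fbgemm APIs.
EMBEDDING_VALUE = 2.0  # assumption: every embedding row is filled with 2.0


def local_mean_then_sum(ids_per_rank):
    """Each rank averages over its local ids; the partial means are then summed."""
    total = 0.0
    for ids in ids_per_rank:
        if ids:  # a rank holding none of the bag's ids contributes nothing
            total += sum(EMBEDDING_VALUE for _ in ids) / len(ids)
    return total


def sum_then_global_mean(ids_per_rank):
    """Each rank sums its local rows; the total is divided by the global bag size."""
    global_len = sum(len(ids) for ids in ids_per_rank)
    total = sum(EMBEDDING_VALUE for ids in ids_per_rank for _ in ids)
    return total / global_len


bag = [[0, 1], [2]]  # ids [0, 1] on rank 0, id [2] on rank 1
observed = local_mean_then_sum(bag)   # (2 + 2) / 2 + 2 / 1
expected = sum_then_global_mean(bag)  # (2 + 2 + 2) / 3
print(observed, expected)
```

The first helper mirrors the behavior described above (divide locally, then sum across ranks), and the second shows what a mean over the whole bag would give.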

Replicating steps

```python
import os
import sys
import torch
import torchrec
import torch.distributed as dist
from torchrec.distributed.fbgemm_qcomm_codec import get_qcomm_codecs_registry, QCommsConfig, CommType
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology, ParameterConstraints
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingType,
)
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

dist.init_process_group(backend="nccl")

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")


def init_fn(x: torch.Tensor):
    # Fill every embedding row with 2.0 so the pooled result is easy to check.
    with torch.no_grad():
        x.fill_(2.0)


ebc = torchrec.EmbeddingBagCollection(
    device=torch.device("meta"),
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=4,
            num_embeddings=4,
            feature_names=["product"],
            init_fn=init_fn,
            pooling=torchrec.PoolingType.MEAN,
        ),
    ],
)

# Force row-wise sharding so the table rows are split across ranks.
sharding_types = [ShardingType.ROW_WISE.value]
constraints = {"product_table": ParameterConstraints(sharding_types=sharding_types)}
planner = EmbeddingShardingPlanner(
    constraints=constraints,
)
sharders = [EmbeddingBagCollectionSharder()]
plan = planner.collective_plan(ebc, sharders, pg=dist.GroupMember.WORLD)

apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=ebc.parameters(),
    optimizer_kwargs={"lr": 0.02},
)

model = torchrec.distributed.DistributedModelParallel(
    ebc, sharders=sharders, device=torch.device("cuda"), plan=plan
)
mb = torchrec.KeyedJaggedTensor(
    keys=["product"],
    values=torch.tensor([0, 1, 2]).cuda(),  # keys [0, 1] on rank 0, [2] on rank 1
    lengths=torch.tensor([3], dtype=torch.int64).cuda(),
)
ret = model(mb)  # => this is awaitable
product = ret.to_dict()["product"]  # implicitly calls awaitable.wait(); ec does not have a to_dict attribute

if local_rank == 0:
    print(model.plan)
    print(f'product {product} ')  # result is 4!! (2+2) / 2 + (2) / 1
```

- cmd to run:

```
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun \
    --nnodes 1 \
    --nproc_per_node 2 \
    ./
```

- Output result:


```
param         | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | ------
product_table | row_wise      | fused          | [0, 1]

param         | shard offsets | shard sizes | placement
------------- | ------------- | ----------- | -------------
product_table | [0, 0]        | [2, 4]      | rank:0/cuda:0
product_table | [2, 0]        | [2, 4]      | rank:1/cuda:1

product tensor([[4., 4., 4., 4.]], device='cuda:0', grad_fn=)
```

- The expected result is `product tensor([[2., 2., 2., 2.]])`

JacoCheung commented 1 month ago

Seems like it's addressed in #1772