
[Bug] the number of embeddings in ManagedCollisionCollection must be a multiple of the number of devices #1591

Open · fangleigit opened this issue 8 months ago

fangleigit commented 8 months ago

When changing the number of embeddings to 4091 and mch_size to 1021 in the code below, it throws the following exception (see the sketch of the modified configuration after the repro script):

ValueError: ShardedTensor global_size property does not match from different ranks! Found global_size=torch.Size([3070]) on rank:0, and global_size=torch.Size([3068]) on rank:1.
Traceback (most recent call last):
  File "test2.py", line 143, in <module>
    spmd_sharing_simulation(ShardingType.ROW_WISE)
  File "test2.py", line 139, in spmd_sharing_simulation
    assert 0 == p.exitcode
AssertionError

Repro script:

import os
from typing import Dict, cast

import multiprocess
import torch
import torch.distributed as dist
import torchrec
from torchrec.distributed.mc_embeddingbag import ManagedCollisionEmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    ManagedCollisionModule,
    MCHManagedCollisionModule,
)

def preprocess_func(id: torch.Tensor, hash_size: int) -> torch.Tensor:
    return id % hash_size

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

table_name = "sample"

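# Single embedding bag table; zch_size (3070) + mch_size (1026) in the collision module below add up to num_embeddings (4096).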
tables = [
    torchrec.EmbeddingBagConfig(
        name=table_name,
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=[table_name],
        pooling=torchrec.PoolingType.SUM,
    )
]

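# Managed collision collection: remaps raw feature ids into the table's index space.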
mcc = ManagedCollisionCollection(
    managed_collision_modules={table_name: cast(
        ManagedCollisionModule,
        MCHManagedCollisionModule(
            zch_size=3070,
            mch_size=1026,
            device="meta",
            eviction_interval=1,
            eviction_policy=DistanceLFU_EvictionPolicy(),
            mch_hash_func=preprocess_func,
        ),
    )},
    embedding_configs=tables,
)

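# Wrap the embedding bag collection with the managed collision collection.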
ebc: ManagedCollisionEmbeddingBagCollection = ManagedCollisionEmbeddingBagCollection(
    EmbeddingBagCollection(
        tables=tables,
        device='meta',
    ),
    mcc,
    return_remapped_features=False,
)

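# Per-rank entry point: sets up the process group, plans the sharding, and wraps the module in DistributedModelParallel.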
def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: torch.nn.Module,
    backend: str,
) -> DistributedModelParallel:

    def init_distributed_single_host(
        rank: int,
        world_size: int,
        backend: str,
        # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
    ) -> dist.ProcessGroup:
        os.environ["RANK"] = f"{rank}"
        os.environ["WORLD_SIZE"] = f"{world_size}"
        dist.init_process_group(
            rank=rank, world_size=world_size, backend=backend)
        return dist.group.WORLD

    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    topology = Topology(world_size=world_size, compute_device="cuda")
    pg = init_distributed_single_host(rank, world_size, backend)
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    sharders = [cast(ModuleSharder[torch.nn.Module],
                     ManagedCollisionEmbeddingBagCollectionSharder())]
    plan = planner.collective_plan(module, sharders=sharders, pg=pg)

    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )
    print(f"rank:{rank},sharding plan: {plan}")
    return sharded_model

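# Spawns one process per rank on a single host and checks that every rank exits cleanly.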
def spmd_sharing_simulation(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    world_size=2,
):
    ctx = multiprocess.get_context("spawn")
    processes = []
    for rank in range(world_size):
        p = ctx.Process(
            target=single_rank_execution,
            args=(
                rank,
                world_size,
                {
                    table_name: ParameterConstraints(
                        sharding_types=[sharding_type.value],
                    )
                },
                ebc,
                "nccl"
            ),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        assert 0 == p.exitcode

if __name__ == '__main__':
    spmd_sharing_simulation(ShardingType.ROW_WISE)
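
For reference, a minimal sketch of the failing configuration described at the top of the issue. Only the two blocks whose sizes change are shown; everything else is as in the script above. It assumes zch_size is left at 3070 so that zch_size + mch_size still equals num_embeddings; note that 4091 is not divisible by the world size of 2 used here.

# Hypothetical variant of the script above; only the sizes differ.
tables = [
    torchrec.EmbeddingBagConfig(
        name=table_name,
        embedding_dim=64,
        num_embeddings=4091,  # changed from 4096
        feature_names=[table_name],
        pooling=torchrec.PoolingType.SUM,
    )
]

mcc = ManagedCollisionCollection(
    managed_collision_modules={table_name: cast(
        ManagedCollisionModule,
        MCHManagedCollisionModule(
            zch_size=3070,  # assumed unchanged
            mch_size=1021,  # changed from 1026
            device="meta",
            eviction_interval=1,
            eviction_policy=DistanceLFU_EvictionPolicy(),
            mch_hash_func=preprocess_func,
        ),
    )},
    embedding_configs=tables,
)
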
henrylhtsang commented 8 months ago

Hi, thanks for trying out ManagedCollisionCollection!

Not sure if it's a bug. The thing is, we are only trying to support ManagedCollisionCollection with row-wise sharding, which shards the table evenly across all the GPUs, hence the divisibility requirement.
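
To illustrate the even-split requirement described above (an illustrative calculation only, not TorchRec's internal sharding logic): with row-wise sharding, each of the world_size devices should own the same number of rows, so the table size needs to be a multiple of world_size.

# Illustrative only: why a row count that is not a multiple of world_size
# cannot be split into equal row-wise shards.
def even_rowwise_shards(num_rows: int, world_size: int) -> list:
    if num_rows % world_size != 0:
        raise ValueError(
            f"{num_rows} rows cannot be split evenly across {world_size} devices"
        )
    return [num_rows // world_size] * world_size

print(even_rowwise_shards(4096, 2))  # [2048, 2048] -> the working config above
# even_rowwise_shards(4091, 2)       # would raise: 4091 is not a multiple of 2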

fangleigit commented 7 months ago

> Hi, thanks for trying out ManagedCollisionCollection!
>
> Not sure if it's a bug. The thing is, we are only trying to support ManagedCollisionCollection with row-wise sharding, which shards the table evenly across all the GPUs, hence the divisibility requirement.

Thanks for your quick response. Yes, I tried ManagedCollisionCollection on our data; the performance degraded and the training time also increased significantly. Is there any guideline or documentation on how to set the hyper-parameters for this module (e.g., eviction_interval, zch_size, mch_size), and on which eviction policy, DistanceLFU_EvictionPolicy or LFU_EvictionPolicy, is better in which scenario?

henrylhtsang commented 7 months ago

@fangleigit Thanks. We are still actively developing MCH/ZCH, so we don't have a clear answer yet. Let us know if you figure it out as well!