pytorch / torchrec

Pytorch domain library for recommendation systems
BSD 3-Clause "New" or "Revised" License
1.78k stars 380 forks

Cannot compute the gradients of EmbeddingBagCollection. #1614

Open ZhuYuJin opened 5 months ago

ZhuYuJin commented 5 months ago

I tried to run the demo script with the following command: torchx run -s local_cwd dist.ddp -j 1x1 --script test_installation.py

I tried to print the gradients of embedding_bag_collection. I can see the gradients of the MLP linear layer, but the gradients of embedding_bag_collection appear to be None.
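A plausible explanation, hedged: as we understand it, TorchRec's default FUSED compute kernels apply the optimizer update inside the backward pass, so the embedding weights never receive a `.grad`. The sketch below models that idea in plain PyTorch. `FusedSGDLookup` is a hypothetical autograd function (not a TorchRec or FBGEMM API) that updates the weight in place during backward and returns no weight gradient; a regular `nn.EmbeddingBag` is shown for comparison.

```python
import torch
import torch.nn as nn


class FusedSGDLookup(torch.autograd.Function):
    """Hypothetical stand-in for a fused embedding kernel (not a TorchRec API).

    Forward: sum-pooled lookup of one bag of ids.
    Backward: applies an SGD step to the weight in place and returns None
    for the weight, so weight.grad is never populated.
    """

    @staticmethod
    def forward(ctx, weight, ids, lr):
        ctx.save_for_backward(ids)
        ctx.weight = weight
        ctx.lr = lr
        return weight[ids].sum(dim=0)

    @staticmethod
    def backward(ctx, grad_out):
        (ids,) = ctx.saved_tensors
        with torch.no_grad():
            # Each looked-up row receives the pooled gradient; update it directly.
            ctx.weight.index_add_(0, ids, -ctx.lr * grad_out.expand(len(ids), -1))
        # No gradient is handed back to autograd for weight, ids or lr.
        return None, None, None


def demo():
    ids = torch.tensor([1, 3])

    # Dense path: autograd accumulates a gradient into weight.grad.
    dense = nn.EmbeddingBag(8, 4, mode="sum")
    dense(ids.unsqueeze(0)).sum().backward()

    # "Fused" path: the update happens inside backward; weight.grad stays None.
    fused_weight = nn.Parameter(torch.randn(8, 4))
    before = fused_weight.detach().clone()
    FusedSGDLookup.apply(fused_weight, ids, 0.1).sum().backward()
    changed = not torch.equal(before, fused_weight.detach())

    return dense.weight.grad is None, fused_weight.grad is None, changed


print(demo())  # (False, True, True)
```

The weights still change in the "fused" case, which is why training works even though `.grad` is empty.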

henrylhtsang commented 5 months ago

We can't really inspect these gradients; see https://github.com/pytorch/torchrec/issues/1293

I think Colin was saying that if you switch to the DENSE compute kernel, you can see them, though I haven't tried it myself.

Ye-Tian-Zero commented 3 months ago

@colin2328 Hi, I ran into the same problem. After searching through some issues, I found that gradients are available with the DENSE compute kernel. But I think the DENSE compute kernel was disabled by this commit: https://github.com/pytorch/torchrec/commit/35c05fa00d2ca7e7b6fd577374560a5174be2675

Is there any other way to run an EmbeddingCollection in dense mode?

Ye-Tian-Zero commented 3 months ago

One way I found is to implement my own sharder:

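from typing import List
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel

class EmbeddingCollectionSharderDense(EmbeddingCollectionSharder):
    def compute_kernels(self, sharding_type: str, compute_device_type: str) -> List[str]:
        return super().compute_kernels(sharding_type, compute_device_type) + [EmbeddingComputeKernel.DENSE.value]
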
henrylhtsang commented 3 months ago

@Ye-Tian-Zero see my answer in https://github.com/pytorch/torchrec/issues/1741

TL;DR: for non-data-parallel sharding types, you get better performance with FUSED.

We didn't disable DENSE. We only disabled DENSE in most tests, since it's slower than FUSED.

For data parallel, you must use DENSE.
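The advice above can be condensed into a small rule for choosing which compute kernels to list in ParameterConstraints. This is a hypothetical helper (`kernels_for` is not a TorchRec function), and the strings `"data_parallel"`, `"dense"` and `"fused"` are assumed to match `ShardingType.DATA_PARALLEL.value`, `EmbeddingComputeKernel.DENSE.value` and `EmbeddingComputeKernel.FUSED.value`.

```python
from typing import List

# Assumed to match the .value strings of TorchRec's ShardingType / EmbeddingComputeKernel.
DATA_PARALLEL = "data_parallel"
DENSE = "dense"
FUSED = "fused"


def kernels_for(sharding_type: str, debug_grads: bool = False) -> List[str]:
    """Hypothetical helper encoding the advice in this thread.

    - data-parallel tables must use DENSE;
    - other sharding types normally use FUSED, which is faster;
    - DENSE is only worth requesting when you need to inspect .grad.
    """
    if sharding_type == DATA_PARALLEL:
        return [DENSE]
    return [DENSE] if debug_grads else [FUSED]


print(kernels_for("data_parallel"))               # ['dense']
print(kernels_for("row_wise"))                    # ['fused']
print(kernels_for("row_wise", debug_grads=True))  # ['dense']
```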

Ye-Tian-Zero commented 2 months ago

@henrylhtsang Hi, thank you, but I think the DENSE compute kernel was disabled here:

https://github.com/pytorch/torchrec/commit/35c05fa00d2ca7e7b6fd577374560a5174be2675#diff-72b4a4f205f4de558b2f5731a6697ae80483ee532eb2012d59999e5d7c462200L302-R302

which means the DENSE kernel is no longer a valid option for non-data-parallel sharding types.

Ye-Tian-Zero commented 2 months ago

What I mean by "disabled" is that it can no longer be used with sharded modules.

henrylhtsang commented 2 months ago

@Ye-Tian-Zero It can still be used. Let me know if you run into any problems with it. You can configure it through ParameterConstraints with EmbeddingComputeKernel.

It was disabled in the tests to save testing capacity: DENSE is slower than FUSED, so we usually recommend FUSED. But for debugging purposes, feel free to use DENSE.

Ye-Tian-Zero commented 1 month ago

@henrylhtsang Sorry for the late reply. Could you try this script on a GPU machine?

#!/usr/bin/env python
# coding: utf-8

# In[1]:

import os
import torch
import torchrec

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "21500"

# In[2]:

from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ShardingType
from typing import Dict

large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
    )
    for i in range(large_table_cnt)
]
small_tables = [
    torchrec.EmbeddingConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
    )
    for i in range(small_table_cnt)
]

def gen_constraints(sharding_type: ShardingType = ShardingType.TABLE_WISE) -> Dict[str, ParameterConstraints]:
    # Request the DENSE kernel for both tables so that their gradients are materialized.
    large_table_constraints = {
        "product_table": ParameterConstraints(
            sharding_types=[sharding_type.value],
            compute_kernels=[EmbeddingComputeKernel.DENSE.value],
        ),
    }
    small_table_constraints = {
        "user_table": ParameterConstraints(
            sharding_types=[sharding_type.value],
            compute_kernels=[EmbeddingComputeKernel.DENSE.value],
        ),
    }
    return {**large_table_constraints, **small_table_constraints}

# In[3]:

def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: torch.nn.Module,
    backend: str,
) -> torch.nn.Module:
    import os
    import torch
    import torch.distributed as dist
    from torchrec.distributed.model_parallel import DistributedModelParallel
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan
    from typing import cast
    import torchrec
    from torchrec.distributed.embedding_types import EmbeddingComputeKernel
    from torchrec.distributed.embedding import EmbeddingCollectionSharder
    from typing import List

    def init_distributed_single_host(
        rank: int,
        world_size: int,
        backend: str,
        # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
    ) -> dist.ProcessGroup:
        os.environ["RANK"] = f"{rank}"
        os.environ["WORLD_SIZE"] = f"{world_size}"
        dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
        return dist.group.WORLD

    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    topology = Topology(world_size=world_size, compute_device=device.type)
    pg = init_distributed_single_host(rank, world_size, backend)
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingCollectionSharder())]
    plan: ShardingPlan = planner.collective_plan(module, sharders, pg)
    print(plan)
    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )
    mb = torchrec.KeyedJaggedTensor(
        keys=["product", "user"],
        values=torch.tensor([101, 202, 303, 404, 505, 606], device=device),
        lengths=torch.tensor([2, 0, 1, 1, 1, 1], dtype=torch.int64, device=device),
    )
    # One forward pass; backpropagate through both features.
    out = sharded_model(mb)
    print(out["product"].to_padded_dense().sum())
    (out["product"].to_padded_dense().sum() + out["user"].to_padded_dense().sum()).backward()
    # Tables on a FUSED kernel leave p.grad as None, so guard before calling .sum().
    print([p.grad.sum() if p.grad is not None else None for p in sharded_model.parameters()])
    print(f"rank:{rank},sharding plan: {plan}")
    return sharded_model

# In[4]:

import multiprocess  # dill-based fork of multiprocessing; can pickle notebook-defined functions

def spmd_sharing_simulation(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    world_size: int = 2,
):
    ctx = multiprocess.get_context("spawn")
    processes = []
    for rank in range(world_size):
        p = ctx.Process(
            target=single_rank_execution,
            args=(
                rank,
                world_size,
                gen_constraints(sharding_type),
                ebc,
                "nccl",
            ),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        assert 0 == p.exitcode

# In[5]:

ebc = torchrec.EmbeddingCollection(
    device=torch.device("cuda"),
    tables=[
        torchrec.EmbeddingConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        ),
        torchrec.EmbeddingConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
        )
    ]
)

# In[6]:

spmd_sharing_simulation(ShardingType.ROW_WISE)

This code eventually fails with the error below:

No available compute kernels after applying user provided constraints for product_table
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 315, in _bootstrap
    self.run()
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_30517/3096772556.py", line 43, in single_rank_execution
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/planners.py", line 187, in collective_plan
    return invoke_on_rank_and_broadcast_result(
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/collective_utils.py", line 53, in invoke_on_rank_and_broadcast_result
    res = func(*args, **kwargs)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/planners.py", line 217, in plan
    search_space = self._enumerator.enumerate(
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/enumerators.py", line 166, in enumerate
    raise RuntimeError(
RuntimeError: No available sharding type and compute kernel combination after applying user provided constraints for product_table
Process SpawnProcess-2:
Traceback (most recent call last):
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 315, in _bootstrap
    self.run()
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_30517/3096772556.py", line 43, in single_rank_execution
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/planners.py", line 187, in collective_plan
    return invoke_on_rank_and_broadcast_result(
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/collective_utils.py", line 58, in invoke_on_rank_and_broadcast_result
    dist.broadcast_object_list(object_list, rank, group=pg)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2603, in broadcast_object_list
    broadcast(object_sizes_tensor, src=src, group=group)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1906, in broadcast
    work = default_pg.broadcast([tensor], opts)
RuntimeError: [1] is setting up NCCL communicator and retrieving ncclUniqueId from [0] via c10d key-value store by key '0', but store->get('0') got error: Connection reset by peer. This may indicate a possible application crash on rank 0 or a network set up issue.
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[6], line 1
----> 1 spmd_sharing_simulation(ShardingType.ROW_WISE)

Cell In[4], line 25, in spmd_sharing_simulation(sharding_type, world_size)
     23 for p in processes:
     24     p.join()
---> 25     assert 0 == p.exitcode

AssertionError:

henrylhtsang commented 1 month ago

@Ye-Tian-Zero Okay, you're right, we did ban that combination. Can you try allowing DENSE here? https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_types.py#L415
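To see why the planner raises "No available sharding type and compute kernel combination", here is a simplified, hypothetical model in plain Python (no TorchRec imports): the sharder advertises compute kernels per sharding type, the planner keeps only the (sharding type, kernel) pairs that the user's ParameterConstraints also allow, and an empty result is an error. `ToySharder`, `ToyDenseSharder` and `enumerate_options` are illustrative names, and the kernel lists are an assumption based on the behavior described in this thread (DENSE only for data parallel), not a copy of embedding_types.py. `ToyDenseSharder` mirrors the EmbeddingCollectionSharderDense workaround posted earlier.

```python
from typing import Dict, List, Tuple


class ToySharder:
    """Hypothetical stand-in for a sharder's compute_kernels() (assumed behavior)."""

    def compute_kernels(self, sharding_type: str) -> List[str]:
        if sharding_type == "data_parallel":
            return ["dense"]
        return ["fused"]


class ToyDenseSharder(ToySharder):
    """Mirrors the EmbeddingCollectionSharderDense workaround: also offer DENSE."""

    def compute_kernels(self, sharding_type: str) -> List[str]:
        kernels = super().compute_kernels(sharding_type)
        return kernels if "dense" in kernels else kernels + ["dense"]


def enumerate_options(
    sharder: ToySharder,
    table: str,
    constraints: Dict[str, Dict[str, List[str]]],
) -> List[Tuple[str, str]]:
    """Keep (sharding_type, kernel) pairs allowed by both the sharder and the user."""
    c = constraints[table]
    options = [
        (st, k)
        for st in c["sharding_types"]
        for k in sharder.compute_kernels(st)
        if k in c["compute_kernels"]
    ]
    if not options:
        raise RuntimeError(
            "No available sharding type and compute kernel combination "
            f"after applying user provided constraints for {table}"
        )
    return options


constraints = {"product_table": {"sharding_types": ["row_wise"], "compute_kernels": ["dense"]}}
try:
    enumerate_options(ToySharder(), "product_table", constraints)
except RuntimeError as e:
    print(e)
print(enumerate_options(ToyDenseSharder(), "product_table", constraints))  # [('row_wise', 'dense')]
```

In this model, allowing DENSE on the sharder side is what makes a ROW_WISE + DENSE constraint satisfiable; the constraint alone cannot add a kernel the sharder does not offer.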