[Open] ZhuYuJin opened this issue 5 months ago
We can't really inspect the gradients; see https://github.com/pytorch/torchrec/issues/1293
I think Colin was saying that if you switch to the dense compute kernel, you can see them, though I haven't tried it myself.
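The FUSED/DENSE difference behind "you can't see the gradients" can be sketched with a toy model in plain Python. This is not TorchRec or FBGEMM code, and `ToyEmbeddingTable` and its methods are hypothetical. It assumes that a FUSED kernel applies the optimizer update inside the backward pass, so no gradient is kept for the table. A DENSE-style kernel instead accumulates a gradient that can be inspected and applied later by a separate optimizer.

```python
from typing import List, Optional


class ToyEmbeddingTable:
    """Hypothetical toy table (not TorchRec/FBGEMM code) with sum pooling."""

    def __init__(self, num_embeddings: int, dim: int, fused: bool, lr: float = 0.5) -> None:
        self.weight: List[List[float]] = [[1.0] * dim for _ in range(num_embeddings)]
        self.fused = fused
        self.lr = lr
        # Gradient buffer; only the DENSE-style path ever fills it.
        self.grad: Optional[List[List[float]]] = None

    def forward(self, ids: List[int]) -> List[float]:
        # Sum the looked-up rows, like an embedding bag with sum pooling.
        dim = len(self.weight[0])
        return [sum(self.weight[i][d] for i in ids) for d in range(dim)]

    def backward(self, ids: List[int], grad_out: List[float]) -> None:
        # With sum pooling, every looked-up row receives grad_out.
        if self.fused:
            # FUSED-style: apply an SGD step right away; no gradient is stored.
            for i in ids:
                self.weight[i] = [w - self.lr * g for w, g in zip(self.weight[i], grad_out)]
        else:
            # DENSE-style: accumulate a full-size gradient that can be inspected later.
            if self.grad is None:
                self.grad = [[0.0] * len(row) for row in self.weight]
            for i in ids:
                self.grad[i] = [a + g for a, g in zip(self.grad[i], grad_out)]


dense = ToyEmbeddingTable(num_embeddings=4, dim=2, fused=False)
dense.backward([0, 2], [1.0, 1.0])
print(dense.grad)       # [[1.0, 1.0], [0.0, 0.0], [1.0, 1.0], [0.0, 0.0]]

fused = ToyEmbeddingTable(num_embeddings=4, dim=2, fused=True)
fused.backward([0, 2], [1.0, 1.0])
print(fused.grad)       # None
print(fused.weight[0])  # [0.5, 0.5]
```

In the fused variant the weights have already moved when backward returns, which is why there is nothing left to print as a gradient. The dense variant trades that speed for an inspectable `grad`.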
@colin2328 Hi, I ran into the same problem. After searching through some issues, I found that the gradients are available with the dense compute kernel. But I think the dense compute kernel was disabled by this commit: https://github.com/pytorch/torchrec/commit/35c05fa00d2ca7e7b6fd577374560a5174be2675
Is there any other way to run an embedding collection in dense mode?
One way I found is to implement my own sharder:
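```python
from typing import List
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
class EmbeddingCollectionSharderDense(EmbeddingCollectionSharder):  # also offers the DENSE kernel
    def compute_kernels(self, sharding_type: str, compute_device_type: str) -> List[str]:
        return super().compute_kernels(sharding_type, compute_device_type) + [EmbeddingComputeKernel.DENSE.value]
```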
@Ye-Tian-Zero See my answer in https://github.com/pytorch/torchrec/issues/1741
TL;DR: for non-data-parallel sharding types, you get better performance with FUSED.
We didn't disable DENSE. We disabled DENSE in most tests, since it's slower than FUSED.
For data parallel, you must use DENSE.
@henrylhtsang Hi, thank you, but I think the dense compute kernel was disabled here:
which means the DENSE kernel is no longer a valid option for non-data-parallel modules.
What I mean by "disabled" is that it can no longer be used with sharded modules.
@Ye-Tian-Zero It can still be used. Let me know if you run into any problems with it. You can configure it through ParameterConstraints with EmbeddingComputeKernel.
We disabled it in the tests to save testing capacity: DENSE is slower than FUSED, so we usually recommend FUSED. But for debugging purposes, feel free to use DENSE.
@henrylhtsang Sorry for the late reply. Maybe you can try this script on a GPU machine:
```python
#!/usr/bin/env python
# coding: utf-8

# In[1]:

import os

import torch
import torchrec

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "21500"

# In[2]:

from typing import Dict

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
    )
    for i in range(large_table_cnt)
]
small_tables = [
    torchrec.EmbeddingConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
    )
    for i in range(small_table_cnt)
]


def gen_constraints(sharding_type: ShardingType = ShardingType.TABLE_WISE) -> Dict[str, ParameterConstraints]:
    # Force the given sharding type and the DENSE kernel for both tables of `ebc` (defined below).
    large_table_constraints = {
        "product_table": ParameterConstraints(
            sharding_types=[sharding_type.value],
            compute_kernels=[EmbeddingComputeKernel.DENSE.value],
        )
    }
    small_table_constraints = {
        "user_table": ParameterConstraints(
            sharding_types=[sharding_type.value],
            compute_kernels=[EmbeddingComputeKernel.DENSE.value],
        )
    }
    constraints = {**large_table_constraints, **small_table_constraints}
    return constraints

# In[3]:

def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: torch.nn.Module,
    backend: str,
) -> torch.nn.Module:
    # Imports are repeated here because this function runs in a freshly spawned process.
    import os
    from typing import List, cast

    import torch
    import torch.distributed as dist
    import torchrec
    from torchrec.distributed.embedding import EmbeddingCollectionSharder
    from torchrec.distributed.embedding_types import EmbeddingComputeKernel
    from torchrec.distributed.model_parallel import DistributedModelParallel
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan

    def init_distributed_single_host(
        rank: int,
        world_size: int,
        backend: str,
        # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
    ) -> dist.ProcessGroup:
        os.environ["RANK"] = f"{rank}"
        os.environ["WORLD_SIZE"] = f"{world_size}"
        dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
        return dist.group.WORLD

    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    topology = Topology(world_size=world_size, compute_device="cuda")
    pg = init_distributed_single_host(rank, world_size, backend)
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingCollectionSharder())]
    # With ROW_WISE + DENSE, planning fails here (see the error below).
    plan: ShardingPlan = planner.collective_plan(module, sharders, pg)
    print(plan)

    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )

    # Batch of 3 samples for two features; `lengths` holds one entry per (feature, sample).
    mb = torchrec.KeyedJaggedTensor(
        keys=["product", "user"],
        values=torch.tensor([101, 202, 303, 404, 505, 606]).to(device),
        lengths=torch.tensor([2, 0, 1, 1, 1, 1], dtype=torch.int64).to(device),
    )
    # Run the forward pass once and reuse its output for printing and the loss.
    out = sharded_model(mb)
    print(out["product"].to_padded_dense().sum())
    loss = out["product"].to_padded_dense().sum() + out["user"].to_padded_dense().sum()
    loss.backward()
    # Guard against parameters whose .grad was never populated.
    print([None if p.grad is None else p.grad.sum() for p in sharded_model.parameters()])
    print(f"rank:{rank},sharding plan: {plan}")
    return sharded_model

# In[4]:

# `multiprocess` is a dill-based fork of `multiprocessing`, used here so that
# functions defined in the notebook can be sent to the spawned processes.
import multiprocess


def spmd_sharing_simulation(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    world_size: int = 2,
) -> None:
    ctx = multiprocess.get_context("spawn")
    processes = []
    for rank in range(world_size):
        p = ctx.Process(
            target=single_rank_execution,
            args=(
                rank,
                world_size,
                gen_constraints(sharding_type),
                ebc,  # defined in the next cell
                "nccl",
            ),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
        assert 0 == p.exitcode

# In[5]:

ebc = torchrec.EmbeddingCollection(
    device=torch.device("cuda"),
    tables=[
        torchrec.EmbeddingConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        ),
        torchrec.EmbeddingConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
        ),
    ],
)

# In[6]:

spmd_sharing_simulation(ShardingType.ROW_WISE)
```
This code eventually fails with the error below:
```
No available compute kernels after applying user provided constraints for product_table
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 315, in _bootstrap
    self.run()
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_30517/3096772556.py", line 43, in single_rank_execution
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/planners.py", line 187, in collective_plan
    return invoke_on_rank_and_broadcast_result(
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/collective_utils.py", line 53, in invoke_on_rank_and_broadcast_result
    res = func(*args, **kwargs)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/planners.py", line 217, in plan
    search_space = self._enumerator.enumerate(
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/enumerators.py", line 166, in enumerate
    raise RuntimeError(
RuntimeError: No available sharding type and compute kernel combination after applying user provided constraints for product_table
Process SpawnProcess-2:
Traceback (most recent call last):
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 315, in _bootstrap
    self.run()
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_30517/3096772556.py", line 43, in single_rank_execution
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/planners.py", line 187, in collective_plan
    return invoke_on_rank_and_broadcast_result(
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/collective_utils.py", line 58, in invoke_on_rank_and_broadcast_result
    dist.broadcast_object_list(object_list, rank, group=pg)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2603, in broadcast_object_list
    broadcast(object_sizes_tensor, src=src, group=group)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1906, in broadcast
    work = default_pg.broadcast([tensor], opts)
RuntimeError: [1] is setting up NCCL communicator and retrieving ncclUniqueId from [0] via c10d key-value store by key '0', but store->get('0') got error: Connection reset by peer. This may indicate a possible application crash on rank 0 or a network set up issue.
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[6], line 1
----> 1 spmd_sharing_simulation(ShardingType.ROW_WISE)

Cell In[4], line 25, in spmd_sharing_simulation(sharding_type, world_size)
     23     for p in processes:
     24         p.join()
---> 25         assert 0 == p.exitcode

AssertionError:
```
@Ye-Tian-Zero Okay, you are right, we did ban that combination. Can you try allowing DENSE here? https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_types.py#L415
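Why the planner raised `No available sharding type and compute kernel combination` can be modeled with a toy sketch of its enumeration step. This is plain Python with hypothetical names, not the real TorchRec enumerator. For each table it keeps only (sharding type, kernel) pairs that the sharder offers *and* the user's constraints allow, and it raises when none remain. The kernel lists are simplified assumptions based on this thread: DENSE only for data parallel, FUSED otherwise. The real base sharder also offers other kernels, such as UVM-backed ones. `dense_enabled_compute_kernels` mirrors the `EmbeddingCollectionSharderDense` override posted earlier.

```python
from typing import Callable, List, Tuple

# Simplified stand-ins for ShardingType / EmbeddingComputeKernel values.
DATA_PARALLEL, TABLE_WISE, ROW_WISE = "data_parallel", "table_wise", "row_wise"
DENSE, FUSED = "dense", "fused"

KernelFn = Callable[[str], List[str]]


def base_compute_kernels(sharding_type: str) -> List[str]:
    # Behaviour described in this thread: DENSE only for data parallel, FUSED otherwise.
    return [DENSE] if sharding_type == DATA_PARALLEL else [FUSED]


def dense_enabled_compute_kernels(sharding_type: str) -> List[str]:
    # Like the custom sharder override: also offer DENSE, without duplicating it.
    kernels = base_compute_kernels(sharding_type)
    return kernels if DENSE in kernels else kernels + [DENSE]


def enumerate_options(
    table: str,
    allowed_sharding_types: List[str],
    allowed_kernels: List[str],
    sharder_kernels: KernelFn,
) -> List[Tuple[str, str]]:
    # Keep only combinations the sharder offers AND the user constraints allow.
    options = [
        (st, k)
        for st in allowed_sharding_types
        for k in sharder_kernels(st)
        if k in allowed_kernels
    ]
    if not options:
        raise RuntimeError(
            "No available sharding type and compute kernel combination "
            f"after applying user provided constraints for {table}"
        )
    return options


if __name__ == "__main__":
    try:
        enumerate_options("product_table", [ROW_WISE], [DENSE], base_compute_kernels)
    except RuntimeError as err:
        print(err)  # the same situation as the ROW_WISE + DENSE script above
    print(enumerate_options("product_table", [ROW_WISE], [DENSE], dense_enabled_compute_kernels))
    # [('row_wise', 'dense')]
```

In this model, widening the sharder's kernel list is enough to make the ROW_WISE + DENSE constraint satisfiable. Whether the real DENSE kernel then runs correctly under row-wise sharding is a separate question that the model does not answer.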
I tried to test the demo script with the following command:
```shell
torchx run -s local_cwd dist.ddp -j 1x1 --script test_installation.py
```
I tried to print the gradients of embedding_bag_collection. I can see the gradients of the MLP linear layer, but the gradients of embedding_bag_collection seem to be None.
![image](https://github.com/pytorch/torchrec/assets/5879410/5e3ae0ef-5320-4ebd-b997-9f88dce3b987)