pytorch / torchrec

PyTorch domain library for recommendation systems
BSD 3-Clause "New" or "Revised" License
1.87k stars 406 forks source link

`AttributeError: 'NoneType' object has no attribute '_dynamo_weak_dynamic_indices'` when using row-wise sharding #1890

Open jiannanWang opened 5 months ago

jiannanWang commented 5 months ago

Description

I’m using torch.compile with DistributedModelParallel. Running the code below results in `AttributeError: 'NoneType' object has no attribute '_dynamo_weak_dynamic_indices'`. Note that this seems to happen only when using row-wise sharding. I would expect the code below to run without such an error.

Environment:

python=3.11.8, torch= '2.2.2+cu121', torchrec= '0.6.0+cu121'.

Reproduction code:

import os
from typing import Callable, List, Union, Tuple
import multiprocessing

import torch
import torch.distributed as dist
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.test_utils.multi_process import MultiProcessContext
from torchrec.distributed.test_utils.test_sharding import create_test_sharder
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingEnv,
    ShardingPlan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.test_utils import get_free_port

class TestModel(nn.Module):
    """Toy recommendation model used by the reproduction script.

    The model combines a dense ``nn.Linear`` layer, an unweighted
    ``EmbeddingBagCollection``, a weighted ``EmbeddingBagCollection`` and a
    final "over" ``nn.Linear`` layer that consumes the concatenation of the
    dense output and every pooled embedding.
    """

    def __init__(self):
        super().__init__()
        # define model parameters
        self.dense_in_feature = 820
        self.dense_out_feature = 784
        # Each entry is [num_embeddings, embedding_dim] for one table.
        self.table_params = [
            [311, 108],
            [739, 408],
        ]
        self.weighted_table_params = [
            [159, 96],
            [69, 24],
            [412, 564],
            [940, 300],
        ]
        self.over_out_feature = 61

        # sparse layer
        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=num_embeddings,
                embedding_dim=embedding_dim,
                name=f"table_{i}",
                feature_names=[f"feature_{i}"],
            )
            for i, (num_embeddings, embedding_dim) in enumerate(self.table_params)
        ]
        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        # weighted sparse layer
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=num_embeddings,
                embedding_dim=embedding_dim,
                name=f"weighted_table_{i}",
                feature_names=[f"weighted_feature_{i}"],
            )
            for i, (num_embeddings, embedding_dim) in enumerate(
                self.weighted_table_params
            )
        ]
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables,
            is_weighted=True,
        )
        # dense layer
        self.dense = nn.Linear(
            in_features=self.dense_in_feature,
            out_features=self.dense_out_feature,
            bias=True,
        )
        # over layer: its input width is the dense output plus one pooled
        # embedding per feature of every (weighted and unweighted) table.
        in_features_concat = (
            self.dense_out_feature
            + sum(
                table.embedding_dim * len(table.feature_names)
                for table in self.tables
            )
            + sum(
                table.embedding_dim * len(table.feature_names)
                for table in self.weighted_tables
            )
        )
        self.over = nn.Linear(
            in_features=in_features_concat,
            out_features=self.over_out_feature,
            bias=True,
        )

    def forward(
        self,
        input: ModelInput,
        print_intermediate_layer: bool = False,
    ) -> Tuple[
        torch.Tensor,
        Tuple[torch.Tensor, KeyedTensor, KeyedTensor, torch.Tensor],
    ]:
        """Run the model on one batch.

        Args:
            input: batch providing ``float_features``, ``idlist_features`` and
                ``idscore_features``.
            print_intermediate_layer: currently unused; kept so existing
                callers that pass it keep working.

        Returns:
            A pair ``(pred, (dense_r, sparse_r, sparse_weighted_r, over_r))``
            where ``pred`` is the sigmoid of the per-row mean of the over
            layer output and the inner tuple holds the intermediate outputs
            of the dense, sparse, weighted sparse and over layers.
        """
        # dense, sparse, weighted sparse layer output
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        # Merge both embedding outputs into one KeyedTensor so every feature
        # can be looked up by name below.
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key()
            + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )
        _features = [
            feature for table in self.tables for feature in table.feature_names
        ]
        _weighted_features = [
            feature
            for table in self.weighted_tables
            for feature in table.feature_names
        ]

        # concat dense, sparse, weighted sparse layer output (in that order)
        ret_list = [dense_r]
        ret_list.extend(result[feature_name] for feature_name in _features)
        ret_list.extend(result[feature_name] for feature_name in _weighted_features)
        ret_concat = torch.cat(ret_list, dim=1)
        # over layer output
        over_r = self.over(ret_concat)
        # sigmoid output
        pred = torch.sigmoid(torch.mean(over_r, dim=1))

        return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)

def sharding_single_rank_test(
    rank: int,
    world_size: int,
    model,
    inputs,
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    compiled=True,
) -> None:
    """Shard ``model`` on one rank, run a forward pass and all-gather outputs.

    Args:
        rank: index of this worker process.
        world_size: total number of worker processes.
        model: the unsharded model; wrapped with ``torch.compile`` when
            ``compiled`` is true.
        inputs: indexed as ``inputs[0][1][rank]`` to obtain this rank's batch.
        sharders: module sharders handed to the planner and to
            ``DistributedModelParallel``.
        backend: backend name passed to ``MultiProcessContext``.
        compiled: whether to wrap the model with ``torch.compile`` first.
    """
    with MultiProcessContext(rank, world_size, backend) as ctx:
        device = ctx.device
        process_group = ctx.pg

        module = torch.compile(model) if compiled else model
        module = module.to(device)

        topology = Topology(world_size, device.type)
        planner = EmbeddingShardingPlanner(topology=topology)
        plan: ShardingPlan = planner.collective_plan(module, sharders, process_group)

        sharded_model = DistributedModelParallel(
            module,
            env=ShardingEnv.from_process_group(process_group),
            plan=plan,
            sharders=sharders,
            device=device,
        )

        # Run a single forward step of the sharded model on this rank's batch.
        batch = inputs[0][1][rank].to(device)
        prediction, intermediates = sharded_model(batch)
        dense_out, sparse_out, weighted_out, over_out = intermediates

        # Gather every rank's prediction.
        gathered_pred = [torch.empty_like(prediction) for _ in range(world_size)]
        dist.all_gather(gathered_pred, prediction, group=process_group)

        # Gather every rank's dense layer output.
        gathered_dense = [torch.empty_like(dense_out) for _ in range(world_size)]
        dist.all_gather(gathered_dense, dense_out, group=process_group)

        # Gather the unweighted embedding outputs, one collective per key.
        gathered_sparse = {}
        for key, tensor in sparse_out.to_dict().items():
            buffers = [torch.empty_like(tensor) for _ in range(world_size)]
            gathered_sparse[key] = buffers
            dist.all_gather(buffers, tensor.contiguous(), group=process_group)

        # Gather the weighted embedding outputs, one collective per key.
        gathered_weighted = {}
        for key, tensor in weighted_out.to_dict().items():
            buffers = [torch.empty_like(tensor) for _ in range(world_size)]
            gathered_weighted[key] = buffers
            dist.all_gather(buffers, tensor.contiguous(), group=process_group)

        # Gather every rank's over layer output.
        gathered_over = [torch.empty_like(over_out) for _ in range(world_size)]
        dist.all_gather(gathered_over, over_out, group=process_group)

def setUp():
    """Prepare environment variables and torch flags for the worker processes.

    Sets the rendezvous address/port and transport-related environment
    variables, turns on deterministic algorithms, and, when CUDA is
    available, disables TF32 and sets the cuBLAS workspace configuration.
    """
    env = os.environ
    env["MASTER_ADDR"] = "localhost"
    env["MASTER_PORT"] = str(get_free_port())
    env["GLOO_DEVICE_TRANSPORT"] = "TCP"
    env["NCCL_SOCKET_IFNAME"] = "lo"

    torch.use_deterministic_algorithms(True)
    if not torch.cuda.is_available():
        return

    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False
    env["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def run_multi_process_test(
    callable: Callable[
        ...,
        None,
    ],
    world_size: int,
    # pyre-ignore
    **kwargs,
) -> None:
    """Run ``callable`` once per rank in separate ``forkserver`` processes.

    Args:
        callable: worker entry point; invoked with ``kwargs`` plus the
            ``rank`` and ``world_size`` keyword arguments.
        world_size: number of worker processes to start.
        **kwargs: extra keyword arguments forwarded to every worker.

    Raises:
        RuntimeError: if any worker finishes with a non-zero exit code.
            Without this check a crashed rank would go unnoticed by the
            parent process.
    """
    setUp()
    ctx = multiprocessing.get_context("forkserver")
    processes = []
    for rank in range(world_size):
        # Build a fresh dict per rank instead of mutating the shared kwargs.
        worker_kwargs = {**kwargs, "rank": rank, "world_size": world_size}
        process = ctx.Process(
            target=callable,
            kwargs=worker_kwargs,
        )
        process.start()
        processes.append(process)

    # Join every worker before inspecting exit codes so none is left running.
    for process in processes:
        process.join()

    failed = {
        rank: process.exitcode
        for rank, process in enumerate(processes)
        if process.exitcode != 0
    }
    if failed:
        raise RuntimeError(
            f"{len(failed)} of {world_size} worker process(es) exited "
            f"abnormally (rank -> exit code): {failed}"
        )

def main_test(
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    world_size: int,
    compiled: bool,
) -> None:
    """Build the model and its inputs, then run the sharding test on all ranks.

    Args:
        sharders: module sharders forwarded to each worker.
        backend: backend name forwarded to each worker.
        world_size: number of worker processes to launch.
        compiled: whether each worker wraps the model with ``torch.compile``.
    """
    model = TestModel()
    generated_input = ModelInput.generate(
        batch_size=1200,
        world_size=world_size,
        num_float_features=model.dense_in_feature,
        tables=model.tables,
        weighted_tables=model.weighted_tables,
    )

    # Keyword arguments that every worker receives in addition to its rank.
    worker_kwargs = {
        "model": model,
        "inputs": [generated_input],
        "sharders": sharders,
        "backend": backend,
        "compiled": compiled,
    }
    run_multi_process_test(
        callable=sharding_single_rank_test,
        world_size=world_size,
        **worker_kwargs,
    )

if __name__ == "__main__":
    # Reproduce the failure: row-wise sharded embedding bag collections,
    # two ranks on the NCCL backend, with torch.compile enabled.
    main_test(
        sharders=[
            create_test_sharder("embedding_bag_collection", "row_wise", "dense")
        ],
        backend="nccl",
        world_size=2,
        compiled=True,
    )

Log:

The error message is copied below.

Traceback (most recent call last):
  File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/mnt/tests/reproduce_nccl_row_wise.py", line 154, in sharding_single_rank_test
    local_pred, (dense_r, sparse_r, sparse_weighted_r, over_r) = local_model(local_input)
                                                                 ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/model_parallel.py", line 288, in forward
    return self._dmp_wrapped_module(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1523, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/tests/reproduce_nccl_row_wise.py", line 86, in forward
    def forward(
  File "/mnt/tests/reproduce_nccl_row_wise.py", line 93, in resume_in_forward
    sparse_r = self.sparse(input.idlist_features)
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/types.py", line 747, in forward
    dist_input = self.input_dist(ctx, *input, **kwargs).wait().wait()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/embeddingbag.py", line 756, in input_dist
    def input_dist(
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/embeddingbag.py", line 782, in resume_in_input_dist
    awaitables.append(input_dist(features_by_shard))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/sharding/rw_sharding.py", line 292, in forward
    def forward(
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 901, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
           ^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 259, in runtime_wrapper
    t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy()
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute '_dynamo_weak_dynamic_indices'
PaulZhang12 commented 4 months ago

Dynamo isn't yet supported for every sharding combination in TorchRec, cc @IvanKobzarev to verify

jiannanWang commented 4 months ago

Thank you for your reply! Could you please tell me which sharding types are currently supported with Dynamo?