pytorch / torchrec

Pytorch domain library for recommendation systems
BSD 3-Clause "New" or "Revised" License
1.87k stars 406 forks source link

The optimizer state key names differ when using `data_parallel` for embedding sharding compared to when using `row_wise` #2394

Open tiankongdeguiji opened 4 days ago

tiankongdeguiji commented 4 days ago

This problem can be reproduced with the following command: `torchrun --master_addr=127.0.0.1 --master_port=1234 --nnodes=1 --nproc-per-node=1 --node_rank=0 test_optimizer_state.py --sharding_type $SHARDING_TYPE`, in an environment with torchrec==0.8.0+cu121, torch==2.4.0+cu121, and fbgemm-gpu==0.8.0+cu121.

When `SHARDING_TYPE=row_wise`, it prints:

['state.sparse.ebc.embedding_bags.table_0.weight.table_0.momentum1', 'state.sparse.ebc.embedding_bags.table_0.weight.table_0.exp_avg_sq', ...]

When `SHARDING_TYPE=data_parallel`, it prints:

['state.sparse.ebc.embedding_bags.table_0.weight.step', 'state.sparse.ebc.embedding_bags.table_0.weight.exp_avg', 'state.sparse.ebc.embedding_bags.table_0.weight.exp_avg_sq', ...]

The keys correspond as follows: `xxx.weight.table_0.momentum1` -> `xxx.weight.exp_avg`, and `xxx.weight.table_0.exp_avg_sq` -> `xxx.weight.exp_avg_sq`.

We may resume training from a checkpoint on a cluster of a different size, which can lead to a different sharding plan; because the optimizer state key names differ between sharding types, the optimizer state is then not loaded correctly.

test_optimizer_state.py

import os
import torch
import argparse

from torch import distributed as dist
from torch.distributed.checkpoint._nested_dict import flatten_state_dict

from torchrec.distributed.comm import get_local_size
from torchrec.distributed.planner import EmbeddingShardingPlanner
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
    get_default_sharders
)
from torchrec.distributed.planner.types import Topology, ParameterConstraints
from torchrec.distributed.embedding_types import ShardingType
from torchrec.distributed.test_utils.test_model import TestSparseNN, ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.distributed.planner.storage_reservations import (
    HeuristicalStorageReservation,
)
from torchrec.optim import optimizers
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.optim.apply_optimizer_in_backward import (
    apply_optimizer_in_backward,  # NOQA
)

# Command-line options for the reproduction script.
parser = argparse.ArgumentParser(
    description=(
        "Print the flattened optimizer state_dict keys produced when every "
        "embedding table uses the given sharding type."
    )
)
parser.add_argument(
    "--sharding_type",
    type=str,
    default="data_parallel",
    help="Sharding type forced on every table, e.g. data_parallel or row_wise.",
)
# parse_known_args tolerates (and here discards) any unrecognized arguments.
args, _ = parser.parse_known_args()

BATCH_SIZE = 8196

# torchrun sets LOCAL_RANK: the index of this process on its node, which
# selects the GPU to use.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device: torch.device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl")
# Take the global rank and world size from the initialized process group.
# LOCAL_RANK equals the global rank only on a single node, and falling back
# to a WORLD_SIZE of 0 is not a valid world size.
rank = dist.get_rank()
world_size = dist.get_world_size()

# Four small embedding-bag tables; table_i is fed by feature_i.
tables = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name=f"table_{i}",
        feature_names=[f"feature_{i}"],
    )
    for i in range(4)
]
topology = Topology(
    local_world_size=get_local_size(),
    world_size=world_size,
    compute_device=device.type,
)
# Force every table onto the sharding type chosen on the command line so the
# optimizer-state key layout of different sharding types can be compared.
constraints = {
    t.name: ParameterConstraints(sharding_types=[args.sharding_type])
    for t in tables
}
planner = EmbeddingShardingPlanner(
    topology=topology,
    batch_size=BATCH_SIZE,
    debug=True,
    storage_reservation=HeuristicalStorageReservation(percentage=0.7),
    constraints=constraints,
)
# The sparse part is created on the meta device; presumably
# DistributedModelParallel materializes it on `device` while sharding.
model = TestSparseNN(
    tables=tables, num_float_features=10, sparse_device=torch.device("meta")
)

# Attach Adam as the in-backward optimizer for the sparse parameters.
apply_optimizer_in_backward(
    optimizers.Adam, model.sparse.parameters(), {"lr": 0.01}
)
# Plan collectively over the default (WORLD) process group.
plan = planner.collective_plan(
    model, get_default_sharders(), dist.GroupMember.WORLD
)
model = DistributedModelParallel(module=model, device=device, plan=plan)
# Regular Adam for the remaining parameters (presumably those without an
# in-backward optimizer), combined with the model's fused optimizer.
dense_optimizer = KeyedOptimizerWrapper(
    dict(in_backward_optimizer_filter(model.named_parameters())),
    lambda params: torch.optim.Adam(params, lr=0.001),
)
optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])

# Run one training step so the optimizers populate their state.
_, local_batches = ModelInput.generate(
    batch_size=BATCH_SIZE,
    world_size=world_size,
    num_float_features=10,
    tables=tables,
    weighted_tables=[],
)
loss, _ = model(local_batches[rank].to(device))
torch.sum(loss).backward()
optimizer.step()

# Print the flattened optimizer state_dict keys (sorted, so output from
# different sharding types can be compared directly).
print(sorted(flatten_state_dict(optimizer.state_dict())[0].keys()))

dist.destroy_process_group()
tiankongdeguiji commented 4 days ago

Hi @henrylhtsang @IvanKobzarev @joshuadeng @PaulZhang12, could you take a look at this problem? I think it may be related to the code here: https://github.com/pytorch/torchrec/blob/release/v0.8.0/torchrec/distributed/batched_embedding_kernel.py#L472

tiankongdeguiji commented 11 hours ago

Hi @sarckk @TroyGarden, could you take a look at this problem?