pytorch / torchrec

PyTorch domain library for recommendation systems
https://pytorch.org/torchrec/
BSD 3-Clause "New" or "Revised" License

[Question] Does TorchRec support checkpointing (save/load)? #2534

Open JacoCheung opened 2 weeks ago

JacoCheung commented 2 weeks ago

Hi team, I would like to know how to save and load a sharded embedding collection via state_dict. Specifically:

  1. How many files should I save? Should each rank save its own shard to an exclusive file, or should a single rank gather the whole embedding and store it as one file? How should I handle the case where both DP and MP are applied?

  2. If each rank maintains its own shard file, how can I load and re-shard in a new distributed environment where the number of GPUs differs from the one the model was saved with?

  3. If there is a single saved file, how should I load and re-shard it, especially in a multi-node environment?

It would be very helpful if someone could provide sample code! Thanks!

iamzainhuda commented 2 weeks ago

1) You should keep a checkpoint per rank, since we do not collectively gather the whole embedding. If you wanted to, you could gather it and then reconstruct the sharded state dict from the original sharding plan, although I wouldn't recommend this. You should be able to use the torch.distributed.checkpoint utilities for TorchRec models; see the first sketch at the end of this comment.

2) For changing the number of GPUs, you would need to understand how the sharding changes. Am I correct in understanding that you want to go from a model sharded on 8 GPUs to loading on 16 GPUs? Resharding is the important step here, and you would have to do it yourself before you load; TorchRec doesn't have any utilities for this.

3) You can broadcast the parameters/state to the other ranks as you load, via a pre_load_state_dict_hook on top of DMP; see the second sketch below.
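
For (1), here is a minimal sketch of the per-rank save/load flow with torch.distributed.checkpoint (DCP). It assumes a recent PyTorch where dcp.save/dcp.load exist, a DMP-wrapped module named model, and a made-up checkpoint directory; DCP understands ShardedTensor, so each rank writes and reads only its own shards.

import torch.distributed.checkpoint as dcp

# Save: every rank participates; DCP writes per-rank files plus metadata.
state = {"model": model.state_dict()}
dcp.save(state, checkpoint_id="/tmp/ebc_ckpt")  # the path is an assumption

# Load: construct and shard the model first, then load in place.
state = {"model": model.state_dict()}
dcp.load(state, checkpoint_id="/tmp/ebc_ckpt")
model.load_state_dict(state["model"])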

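And for (3), a rough sketch of the broadcast-on-load idea. The hook registration call (_register_load_state_dict_pre_hook) is a private, version-dependent PyTorch API, and full_state_dict (real values on rank 0, same-shaped placeholders elsewhere) is an assumption, not something TorchRec builds for you.

import torch
import torch.distributed as dist

def broadcast_from_rank0(module, state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
    # Rank 0 holds the tensors read from the single file; every other rank
    # receives the real values here, before load_state_dict consumes them.
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            dist.broadcast(value, src=0)

model._register_load_state_dict_pre_hook(broadcast_from_rank0, with_module=True)
model.load_state_dict(full_state_dict)
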
JacoCheung commented 4 days ago

Thanks for your reply! Closing it as completed.

JacoCheung commented 1 day ago

Sorry @iamzainhuda, I have to reopen this because I ran into another issue, this time with the Adam optimizer. Say I have an EmbeddingCollection whose optimizer is fused with the backward pass. optimizer.state_dict() returns nothing but the "momentum1/2" tensors; other state such as lr and decay is gone. I think the problem is here.


import os
import sys
sys.path.append(os.path.abspath('/home/scratch.junzhang_sw/workspace/torchrec'))
import torch
import torch.distributed as dist
import torchrec

# Single-rank process group for the repro.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="nccl")

ebc = torchrec.EmbeddingCollection(
    device=torch.device("meta"),
    tables=[
        torchrec.EmbeddingConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        ),
        torchrec.EmbeddingConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
        ),
    ],
)

from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

# Fuse the Adam optimizer into the backward pass.
apply_optimizer_in_backward(
    optimizer_class=torch.optim.Adam,
    params=ebc.parameters(),
    optimizer_kwargs={"lr": 0.02},
)

from torchrec.distributed.fbgemm_qcomm_codec import get_qcomm_codecs_registry, QCommsConfig, CommType
from torchrec.distributed.embedding import EmbeddingCollectionSharder

sharder = EmbeddingCollectionSharder(
    use_index_dedup=True,
    # qcomm_codecs_registry=get_qcomm_codecs_registry(
    #     qcomms_config=QCommsConfig(
    #         forward_precision=CommType.FP16,
    #         backward_precision=CommType.BF16,
    #     )
    # ),
)

dp_rank = dist.get_rank()
model = torchrec.distributed.DistributedModelParallel(
    ebc, sharders=[sharder], device=torch.device("cuda")
)

mb = torchrec.KeyedJaggedTensor(
    keys=["product", "user"],
    values=torch.tensor([101, 201, 101, 404, 404, 606, 606, 606]).cuda(),
    lengths=torch.tensor([2, 0, 1, 1, 1, 3], dtype=torch.int64).cuda(),
)

ret = model(mb)           # forward returns an awaitable
product = ret["product"]  # indexing implicitly calls awaitable.wait()

The above model gives me optimizer state like:

>>> model.fused_optimizer.state_dict()['state']['embeddings.product_table.weight'].keys()
dict_keys(['product_table.momentum1', 'product_table.exp_avg_sq'])
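
For reference, a quick way to dump everything the fused optimizer exposes (plain inspection code; in a vanilla torch.optim state_dict, hyperparameters such as lr would live under "param_groups"):

sd = model.fused_optimizer.state_dict()
print(sd.keys())               # which top-level sections are present
print(sd.get("param_groups"))  # hyperparameters (lr, betas, ...) if any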