pytorch / torchrec

Pytorch domain library for recommendation systems
https://pytorch.org/torchrec/

[Question/Bug] DP sharding parameters are inconsistent with others. #2563

Open JacoCheung opened 1 week ago

JacoCheung commented 1 week ago

Description: Hi, torchrec team. I'm using EmbeddingCollection and constraining the sharding type to DATA_PARALLEL. I should then be able to get the parameters and pass them to my optimizers. However, there are several problems I encountered.

Looking forward to any input. Thanks!

## `TableBatchedEmbeddingSlice` is not a leaf tensor

`.parameters()` (or `embedding_collection._dmp_wrapped_module.embeddings['0'].weight`) returns a `TableBatchedEmbeddingSlice`, which is unexpectedly not a leaf tensor.

    ebc = EmbeddingCollection(
        device=torch.device("meta"),
        tables=[
            EmbeddingConfig(
                name="product_table",
                embedding_dim=4,
                num_embeddings=4,
                feature_names=["product"],
                init_fn=init_fn,
            ),
        ],
    )
    ...
    model = DMP(ebc)
    weight = model._dmp_wrapped_module.embeddings['product_table'].weight
    weight.is_leaf  # False

Unfortunately, my app requires this flag to perform some operations.
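
For context on why the flag matters: stock `torch.optim` optimizers reject non-leaf tensors, so a non-leaf weight cannot be handed to them directly. A plain-tensor illustration (not torchrec-specific):

    import torch

    leaf = torch.randn(4, 4, requires_grad=True)  # created by the user -> leaf
    view = leaf[:2]                               # view of a leaf -> not a leaf
    print(leaf.is_leaf, view.is_leaf)             # True False

    torch.optim.SGD([leaf], lr=0.1)               # accepted
    try:
        torch.optim.SGD([view], lr=0.1)           # rejected: can't optimize a non-leaf Tensor
    except ValueError as err:
        print(err)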

## model.bfloat16() detaches the weight storage

When I convert the whole model to lower precision, say bf16, the underlying DP table storage is not affected, while the weight/params accessors are converted as expected. The weight and the storage then appear to be untied, so the optimizer has no effect on the original storage. A reproducible script is below:

    import os

    import torch
    import torch.distributed as dist
    import torchrec
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
    from torchrec.modules.embedding_modules import EmbeddingCollection
    from torchrec.modules.embedding_configs import EmbeddingConfig

    from fbgemm_gpu.split_embedding_configs import SparseType

    from torch.distributed.optim import (
        _apply_optimizer_in_backward as apply_optimizer_in_backward,
    )
    from torchrec.distributed.embedding import EmbeddingCollectionSharder
    from torchrec.distributed.planner import EmbeddingShardingPlanner, ParameterConstraints
    from torchrec.distributed.types import ShardingType
    from torchrec.optim.optimizers import in_backward_optimizer_filter

    torch.manual_seed(1024)

    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend="nccl")

    def init_fn(x, val=0.1):
        # fill row i of the table with the constant value i
        with torch.no_grad():
            dim0 = x.size(0)
            dim1 = x.size(1)
            natural_seq = torch.arange(0, dim0).cuda().to(x.dtype).unsqueeze(-1)
            x.copy_(natural_seq.expand(-1, dim1))

    ebc = EmbeddingCollection(
        device=torch.device("meta"),
        tables=[
            EmbeddingConfig(
                name="product_table",
                embedding_dim=4,
                num_embeddings=4,
                feature_names=["product"],
                init_fn=init_fn,
            ),
        ],
    )

    sharding_types = [ShardingType.DATA_PARALLEL.value]
    constraints = {"product_table": ParameterConstraints(sharding_types=sharding_types)}

    # overrides the torch.distributed.optim alias imported above
    from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

    planner = EmbeddingShardingPlanner(constraints=constraints)
    sharders = [EmbeddingCollectionSharder(fused_params={'output_dtype': SparseType.BF16})]
    plan = planner.collective_plan(ebc, sharders, pg=dist.GroupMember.WORLD)

    apply_optimizer_in_backward(
        optimizer_class=torch.optim.SGD,
        params=ebc.parameters(),
        optimizer_kwargs={"lr": 0.02},
    )

    model = torchrec.distributed.DistributedModelParallel(
        ebc, sharders=sharders, device=torch.device("cuda"), plan=plan
    )
    model = model.bfloat16()

    # accessor for the DP table weight
    model._dmp_wrapped_module.embeddings['product_table'].weight

    ref_optimizer = torch.optim.SGD(
        dict(in_backward_optimizer_filter(model.named_parameters())).values(),
        lr=0.02,
    )
    mb = KeyedJaggedTensor(
        keys=["product"],
        values=torch.tensor([0, 1, 2]).cuda(),  # keys [0, 1] on rank 0, [2] on rank 1
        lengths=torch.tensor([3], dtype=torch.int64).cuda(),
    )
    ret = model(mb)['product'].values()
    ret.sum().backward()
    ref_optimizer.step()

    params = dict(model.named_parameters())
    weight = model._dmp_wrapped_module.embeddings['product_table'].weight

    print(f"params {params}")
    print(f"weight {model.module._lookups[0].module._emb_modules[0]._emb_module.split_embedding_weights()}")

The `params` get updated while the underlying storage (`split_embedding_weights`) stays the same, and the next lookup still sees the old storage; see the aliasing check sketched after the outputs below.

params:

    [-0.0200, -0.0200, -0.0200, -0.0200],
    [ 0.9805,  0.9805,  0.9805,  0.9805],
    [ 1.9766,  1.9766,  1.9766,  1.9766],
    [ 3.0000,  3.0000,  3.0000,  3.0000]

storage:

    [0., 0., 0., 0.],
    [1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.]
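
For what it's worth, a minimal check (reusing the accessor paths from the repro above, assuming a single rank and a single table) confirms that after `model.bfloat16()` the parameter accessor and the FBGEMM storage no longer alias the same memory:

    weight = model._dmp_wrapped_module.embeddings['product_table'].weight
    storage = model.module._lookups[0].module._emb_modules[0]._emb_module.split_embedding_weights()[0]
    print(weight.dtype, storage.dtype)              # expected: torch.bfloat16 vs torch.float32 after the cast
    print(weight.data_ptr() == storage.data_ptr())  # expected: False once the two copies are untied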


## `AsStridedBackward0` in grad_fn when a `TableBatchedEmbeddingSlice` is an operand
Besides, I find that `next_functions` starts with an `AsStridedBackward0` ahead of `AccumulateGrad` when a `TableBatchedEmbeddingSlice` object is used as an operand, for example:

    myself = params['embeddings.product_table.weight'] * torch.ones_like(params['embeddings.product_table.weight'])

    myself.grad_fn.next_functions
    # ((<AsStridedBackward0 object at 0x7ff51a936440>, 0), (None, 0))

    myself.grad_fn.next_functions[0][0].next_functions
    # ((<AccumulateGrad object at 0x7ff51a936440>, 0),)



I would like to know why there is an `AsStridedBackward0`. One library that I depend on requires accessing the `AccumulateGrad` in only one jump.
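
For reference, a generic autograd-graph walk (a workaround sketch using standard autograd attributes, not a torchrec API) can hop past the extra `AsStridedBackward0` node to reach the `AccumulateGrad`:

    # workaround sketch: walk grad_fn.next_functions until an AccumulateGrad node
    # is found (plain PyTorch autograd traversal, nothing torchrec-specific)
    def find_accumulate_grad(fn):
        seen, stack = set(), [fn]
        while stack:
            node = stack.pop()
            if node is None or node in seen:
                continue
            seen.add(node)
            if type(node).__name__ == "AccumulateGrad":
                return node
            stack.extend(next_fn for next_fn, _ in node.next_functions)
        return None

    acc = find_accumulate_grad(myself.grad_fn)
    print(acc)                 # <AccumulateGrad object ...>
    print(acc.variable.shape)  # the leaf tensor the gradient accumulates into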
PaulZhang12 commented 5 days ago

@TroyGarden

TroyGarden commented 4 days ago

Sorry for the late response. @JacoCheung it looks like an invalid access to the weights. Could you please try using the weights from the state_dict? I'll try to reproduce the issue locally, likely next week due to bandwidth limitations.

TroyGarden commented 1 day ago

@JacoCheung regarding question 1 ("TableBatchedEmbeddingSlice is not a leaf tensor"): I managed to reproduce the issue with world_size = 1. As you can see from the following debugger output, the weight is a TableBatchedEmbeddingSlice, which represents a slice of a table-batched embedding (even though in this test case there is only one slice). You'll have to use the _original_data to get the original weights:

    weight = model._dmp_wrapped_module.embeddings['product_table'].weight
    weight
    Parameter containing:
    Parameter(TableBatchedEmbeddingSlice([[ 0.3090,  0.2935, -0.2901,  0.4279],
                                          [ 0.3136,  0.2422, -0.0231, -0.0045],
                                          [-0.1398, -0.3822,  0.2852, -0.4772],
                                          [ 0.3793, -0.3837, -0.4460,  0.0480]],
                                         requires_grad=True))
    weight.is_leaf
    False
    weight._original_tensor.is_leaf
    True

If you want the leaf weight, you can try using the parameters from the state_dict:

    w1 = model._dmp_wrapped_module.state_dict()['embeddings.product_table.weight']
    w1
    ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4, 4], placement=rank:0/cpu)], size=torch.Size([4, 4]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False)))
    w1.is_leaf
    True

or

    w2 = model._dmp_wrapped_module.embeddings.state_dict()['product_table.weight']
    w2
    tensor([[ 0.3090,  0.2935, -0.2901,  0.4279],
            [ 0.3136,  0.2422, -0.0231, -0.0045],
            [-0.1398, -0.3822,  0.2852, -0.4772],
            [ 0.3793, -0.3837, -0.4460,  0.0480]])
    w2.is_leaf
    True

Hope that answered your question.

TroyGarden commented 20 hours ago

@JacoCheung as for question 2, it's kind of similar to question 1.

  1. You can also use the state_dict to access the weights.
  2. When you call `model.bfloat16()`, it actually copies all the parameters of the old model and casts them to bfloat16. However, the `TableBatchedEmbeddingSlice` is only a slice of the original weights, so its `_original_data` does not switch to the new model's (bfloat16) weights; it still references the old model's (float32) weights. So after the backward pass, the old model's weights won't be updated.

My suggestion is to do the bfloat16 casting before calling DMP, or, more directly, to set the embedding config with dtype=bfloat16.
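
A minimal sketch of that second option (assuming the `data_type` field and the `DataType.BF16` value exist in your torchrec version; adjust the import path if needed):

    from torchrec.modules.embedding_configs import DataType, EmbeddingConfig

    # ask torchrec to create the table weights in bf16 up front,
    # instead of casting the whole model after DMP wrapping
    table = EmbeddingConfig(
        name="product_table",
        embedding_dim=4,
        num_embeddings=4,
        feature_names=["product"],
        data_type=DataType.BF16,
    )

    # alternatively (first option), cast before wrapping so DMP shards an already-bf16 module:
    # ebc = ebc.bfloat16()
    # model = torchrec.distributed.DistributedModelParallel(
    #     ebc, sharders=sharders, device=torch.device("cuda"), plan=plan
    # )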