Open JacoCheung opened 1 week ago
@TroyGarden
sorry for the late response. @JacoCheung it looks like an invalid access to the weights. Could you please try using the weights in the state_dict? I'll try to reproduce the issue locally, but due to bandwidth limitations it will likely be next week.
@JacoCheung regarding question 1, "TableBatchedEmbeddingSlice is not a leaf tensor":
I managed to reproduce the issue with world_size = 1. As you can see from the following debugger output, the weight is a TableBatchedEmbeddingSlice, which represents a slice of a table-batched embedding (even though in this test case there is only one slice). You have to go through its _original_tensor to get the original weights:
weight = model._dmp_wrapped_module.embeddings['product_table'].weight
weight
Parameter containing:
Parameter(TableBatchedEmbeddingSlice([[ 0.3090, 0.2935, -0.2901, 0.4279],
[ 0.3136, 0.2422, -0.0231, -0.0045],
[-0.1398, -0.3822, 0.2852, -0.4772],
[ 0.3793, -0.3837, -0.4460, 0.0480]],
requires_grad=True))
weight.is_leaf
False
weight._original_tensor.is_leaf
True
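As a minimal sketch of what that means in practice (using the model/weight objects from the session above; whether you should touch _original_tensor directly depends on your setup, so treat this as an illustration rather than a recommended API):
weight = model._dmp_wrapped_module.embeddings['product_table'].weight
leaf = weight._original_tensor      # the leaf tensor that autograd accumulates gradients into
assert leaf.is_leaf
# e.g. hand the leaf (not the slice) to code that insists on leaf tensors
leaf.register_hook(lambda grad: print("grad shape:", grad.shape))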
If you want the leaf weight, you can try using the parameters from the state_dict
w1 = model._dmp_wrapped_module.state_dict()['embeddings.product_table.weight']
w1
ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4, 4], placement=rank:0/cpu)], size=torch.Size([4, 4]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False)))
w1.is_leaf
True
or
w2 = model._dmp_wrapped_module.embeddings.state_dict()['product_table.weight']
w2
tensor([[ 0.3090, 0.2935, -0.2901, 0.4279],
[ 0.3136, 0.2422, -0.0231, -0.0045],
[-0.1398, -0.3822, 0.2852, -0.4772],
[ 0.3793, -0.3837, -0.4460, 0.0480]])
w2.is_leaf
True
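Note that the first form (w1) is a ShardedTensor. If you need a plain local tensor out of it, local_shards() is the stock ShardedTensor accessor; a small sketch using the names from above:
w1 = model._dmp_wrapped_module.state_dict()['embeddings.product_table.weight']
local = w1.local_shards()[0].tensor    # regular torch.Tensor owned by this rank
print(local.shape, local.is_leaf)      # torch.Size([4, 4]) True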
Hope that answered your question.
@JacoCheung as for question 2, it's kind of similar to question 1: use the state_dict to access the weights. When you call model.bfloat16(), it actually copies all the parameters in the old model and casts them to bfloat16. However, the TableBatchedEmbeddingSlice is only a slice of the original weights, so its _original_tensor won't switch to the new model's bfloat16 weights; it still references the old model's float32 weights. So after the backward pass, the old model won't be updated. My suggestion is to do the bfloat16 casting before calling DMP, or, more directly, to set the embedding config with dtype=bfloat16.
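A rough sketch of those two suggestions, reusing the ebc/sharders/plan names from the repro in the issue description below (the DataType.BF16 / data_type field names are my reading of torchrec's embedding_configs; double-check them against your torchrec version):
from torchrec.modules.embedding_configs import DataType, EmbeddingConfig

# Option 1: declare the table as bfloat16 in the config itself.
bf16_config = EmbeddingConfig(
    name="product_table",
    embedding_dim=4,
    num_embeddings=4,
    feature_names=["product"],
    data_type=DataType.BF16,  # assumption: this is the knob meant by "dtype=bfloat16"
)

# Option 2: cast the unsharded module before wrapping it, so the slices are
# created over bf16 storage instead of pointing back at fp32 tensors.
ebc = ebc.bfloat16()
model = torchrec.distributed.DistributedModelParallel(
    ebc, sharders=sharders, device=torch.device("cuda"), plan=plan
)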
Description: Hi torchrec team, I'm using EmbeddingCollection and constraining the sharding type to DATA_PARALLEL. Subsequently, I should be able to get the parameters and pass them to my optimizers. However, there are several problems I encountered.
Looking forward to any input. Thanks!
TableBatchedEmbeddingSlice is not a leaf tensor
.parameters() or embedding_collection._dmp_wrapped_module.embeddings['0'].weight returns a TableBatchedEmbeddingSlice, which is unexpectedly not a leaf tensor.
model = DMP(ebc)
weight = model._dmp_wrapped_module.embeddings['product_table'].weight
# This is False
weight.is_leaf
import os
import torch
import torch.distributed as dist
import torchrec
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.modules.embedding_configs import EmbeddingConfig
from fbgemm_gpu.split_embedding_configs import SparseType
from torch.distributed.optim import (
    _apply_optimizer_in_backward as apply_optimizer_in_backward,
)
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, ParameterConstraints
from torchrec.distributed.types import ShardingType
from torchrec.optim.optimizers import in_backward_optimizer_filter

torch.manual_seed(1024)

os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="nccl")

def init_fn(x, val=0.1):
    # initialize every row of the table with its row index
    with torch.no_grad():
        dim0 = x.size(0)
        dim1 = x.size(1)
        natural_seq = torch.arange(0, dim0).cuda().to(x.dtype).unsqueeze(-1)
        x.copy_(natural_seq.expand(-1, dim1))

ebc = EmbeddingCollection(
    device=torch.device("meta"),
    tables=[
        EmbeddingConfig(
            name="product_table",
            embedding_dim=4,
            num_embeddings=4,
            feature_names=["product"],
            init_fn=init_fn,
        ),
    ],
)

# constrain the table to DATA_PARALLEL sharding
sharding_types = [ShardingType.DATA_PARALLEL.value]
constraints = {"product_table": ParameterConstraints(sharding_types=sharding_types)}

from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

planner = EmbeddingShardingPlanner(constraints=constraints)
sharders = [EmbeddingCollectionSharder(fused_params={"output_dtype": SparseType.BF16})]
plan = planner.collective_plan(ebc, sharders, pg=dist.GroupMember.WORLD)

apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=ebc.parameters(),
    optimizer_kwargs={"lr": 0.02},
)

model = torchrec.distributed.DistributedModelParallel(
    ebc, sharders=sharders, device=torch.device("cuda"), plan=plan
)
model = model.bfloat16()

model.module.embeddings['product_table'].weight

ref_optimizer = torch.optim.SGD(
    dict(in_backward_optimizer_filter(model.named_parameters())).values(),
    lr=0.02,
)
mb = KeyedJaggedTensor(
    keys=["product"],
    values=torch.tensor([0, 1, 2]).cuda(),  # key [0,1] on rank0, [2] on rank 1
    lengths=torch.tensor([3], dtype=torch.int64).cuda(),
)
ret = model(mb)['product'].values()
ret.sum().backward()
ref_optimizer.step()

params = dict(model.named_parameters())
weight = model._dmp_wrapped_module.embeddings['product_table'].weight

print(f"params {params}")
print(f"weight {model.module._lookups[0].module._emb_modules[0]._emb_module.split_embedding_weights()}")
[-0.0200, -0.0200, -0.0200, -0.0200],
[ 0.9805,  0.9805,  0.9805,  0.9805],
[ 1.9766,  1.9766,  1.9766,  1.9766],
[ 3.0000,  3.0000,  3.0000,  3.0000]

[0., 0., 0., 0.],
[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]
myself = params['embeddings.product_table.weight'] * torch.ones_like(params['embeddings.product_table.weight'])
# this is ((<AsStridedBackward0 object at 0x7ff51a936440>, 0), (None, 0))
myself.grad_fn.next_functions
# this one is ((<AccumulateGrad object at 0x7ff51a936440>, 0),)
myself.grad_fn.next_functions[0][0].next_functions
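For anyone following this graph walk: the AccumulateGrad node at the end of the chain exposes the leaf it accumulates into through its variable attribute, so a quick sanity check over the names above (a sketch, not part of the original repro) looks like:
acc = myself.grad_fn.next_functions[0][0].next_functions[0][0]  # the AccumulateGrad node
leaf = acc.variable                 # the leaf tensor gradients are accumulated into
print(type(leaf), leaf.is_leaf)     # expected: the underlying (original) weight, is_leaf == True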