Open yjjinjie opened 5 months ago
```python
import torch
from torchrec import KeyedJaggedTensor
from torchrec import EmbeddingBagConfig, EmbeddingConfig
from torchrec import EmbeddingBagCollection, EmbeddingCollection

# Full batch: t1 has 3 rows, t2 has 3 rows.
kt = KeyedJaggedTensor(
    keys=['t1', 't2'],
    values=torch.tensor([0, 0, 0, 0, 2]),
    lengths=torch.tensor([1, 1, 1, 1, 0, 1], dtype=torch.int64),
)

# Deduplicated batch: t1 is collapsed to 1 row; inverse_indices
# should map it back to the 3 logical rows.
kt2 = KeyedJaggedTensor(
    keys=['t1', 't2'],
    values=torch.tensor([0, 0, 2]),
    lengths=torch.tensor([1, 1, 0, 1], dtype=torch.int64),
    stride_per_key_per_rank=[[1], [3]],
    inverse_indices=(['t1', 't2'], torch.tensor([[0, 0, 0], [0, 1, 2]])),
)

eb_configs = [
    EmbeddingBagConfig(num_embeddings=100, embedding_dim=16, name='e1', feature_names=['t1']),
    EmbeddingBagConfig(num_embeddings=100, embedding_dim=16, name='e2', feature_names=['t2']),
]
ebc = EmbeddingBagCollection(eb_configs)
print(ebc(kt)['t1'])
print(ebc(kt2)['t1'])

ec_configs = [
    EmbeddingConfig(num_embeddings=100, embedding_dim=16, name='e1', feature_names=['t1']),
    EmbeddingConfig(num_embeddings=100, embedding_dim=16, name='e2', feature_names=['t2']),
]
ec = EmbeddingCollection(ec_configs)
print(ec(kt)['t1'].lengths().size())
print(ec(kt2)['t1'].lengths().size())
```
Result: the output of `EmbeddingCollection` is not rearranged according to `inverse_indices`; the returned lengths have sizes 3 and 1 respectively, i.e. the `kt2` output stays at the deduplicated batch size instead of being expanded back to 3.
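For reference, the expansion that `inverse_indices` is expected to drive can be sketched in plain torch. The tensors below are copied from the `t1` feature of `kt2` in the repro above; expanding them should recover the `t1` feature of `kt`:

```python
import torch

# Deduplicated t1 feature from kt2 (one row), and its inverse indices
# mapping each of the 3 logical batch rows back to dedup row 0.
t1_lengths = torch.tensor([1])
t1_values = torch.tensor([0])
inverse_t1 = torch.tensor([0, 0, 0])

# Expand lengths back to the full batch via the inverse indices.
expanded_lengths = t1_lengths.index_select(0, inverse_t1)

# Expand values: gather each selected row's slice of the jagged values.
offsets = torch.cat([t1_lengths.new_zeros(1), t1_lengths.cumsum(0)]).tolist()
expanded_values = torch.cat(
    [t1_values[offsets[i]:offsets[i + 1]] for i in inverse_t1.tolist()]
)

print(expanded_lengths)  # tensor([1, 1, 1])
print(expanded_values)   # tensor([0, 0, 0])
```

This matches the `t1` feature of `kt` (`lengths=[1, 1, 1]`, `values=[0, 0, 0]`), which is what the `kt2` path would return if the expansion were applied.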
cc @joshuadeng
Hi @yjjinjie, `EmbeddingCollection` does not currently support variable batch size per feature. This work is being planned, so stay tuned.
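Until that support lands, one possible workaround is to re-expand the deduplicated output manually. The `expand_jagged` helper below is hypothetical (it is not part of the torchrec API); it applies the inverse indices to a jagged feature's lengths and 2-D embedding values:

```python
import torch

def expand_jagged(lengths: torch.Tensor, values: torch.Tensor, inverse: torch.Tensor):
    """Re-expand a deduplicated jagged feature to the full batch.

    lengths: (dedup_batch,) row lengths of the deduplicated feature
    values:  (lengths.sum(), dim) embedding rows (or a 1-D id tensor)
    inverse: (full_batch,) index of each logical row in the dedup batch
    """
    offs = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)]).tolist()
    out_lengths = lengths.index_select(0, inverse)
    out_values = torch.cat([values[offs[i]:offs[i + 1]] for i in inverse.tolist()])
    return out_lengths, out_values

# Toy check: a dedup batch of 2 rows with lengths [2, 1], expanded to 3 rows.
lengths = torch.tensor([2, 1])
values = torch.arange(6.0).reshape(3, 2)
inverse = torch.tensor([1, 0, 0])
out_lengths, out_values = expand_jagged(lengths, values, inverse)
print(out_lengths)       # tensor([1, 2, 2])
print(out_values.shape)  # torch.Size([5, 2])
```

For the repro above, applying this to the lengths and values of the `kt2` output for `t1`, with `inverse = torch.tensor([0, 0, 0])`, would yield a length-3 result shaped like the `kt` output.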