pytorch / torchrec

Pytorch domain library for recommendation systems
https://pytorch.org/torchrec/
BSD 3-Clause "New" or "Revised" License

EmbeddingCollection+KeyedJaggedTensor+vbe the inverse_indices don't work #1895

Open yjjinjie opened 5 months ago

yjjinjie commented 5 months ago

import torch
from torchrec import KeyedJaggedTensor
from torchrec import EmbeddingBagConfig, EmbeddingConfig
from torchrec import EmbeddingBagCollection, EmbeddingCollection

# Full (non-VBE) input: both keys have a batch size of 3.
kt = KeyedJaggedTensor(
    keys=['t1', 't2'],
    values=torch.tensor([0, 0, 0, 0, 2]),
    lengths=torch.tensor([1, 1, 1, 1, 0, 1], dtype=torch.int64),
)

# VBE (variable batch size per feature) input: 't1' is deduplicated to a
# batch of 1 while 't2' keeps a batch of 3; inverse_indices map each logical
# batch row back to its row in the deduplicated per-key batch.
kt2 = KeyedJaggedTensor(
    keys=['t1', 't2'],
    values=torch.tensor([0, 0, 2]),
    lengths=torch.tensor([1, 1, 0, 1], dtype=torch.int64),
    stride_per_key_per_rank=[[1], [3]],
    inverse_indices=(['t1', 't2'], torch.tensor([[0, 0, 0], [0, 1, 2]])),
)
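For context, the expected semantics of `inverse_indices` is that each entry maps a logical batch row back to its row in the deduplicated per-key batch, so re-expanding the output is a gather along dim 0. A minimal pure-torch sketch (the tensors here are illustrative values, not torchrec output):

```python
import torch

# Hypothetical deduplicated per-key embedding rows, as VBE would produce:
# 't1' was computed for a batch of 1, 't2' for a batch of 3.
t1_rows = torch.tensor([[1.0, 2.0]])           # shape (1, dim)
t2_rows = torch.tensor([[3.0], [4.0], [5.0]])  # shape (3, dim)

# inverse_indices from the issue: (['t1', 't2'], [[0, 0, 0], [0, 1, 2]])
inv_t1 = torch.tensor([0, 0, 0])
inv_t2 = torch.tensor([0, 1, 2])

# Re-expansion to the full logical batch of 3 is a gather along dim 0.
t1_full = torch.index_select(t1_rows, 0, inv_t1)  # shape (3, dim)
t2_full = torch.index_select(t2_rows, 0, inv_t2)  # shape (3, dim)
print(t1_full.shape[0], t2_full.shape[0])
```

This is what `EmbeddingBagCollection` does internally for VBE inputs, and what the issue reports `EmbeddingCollection` is not doing.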

eb_configs = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e1',
        feature_names=['t1']
    ),
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e2',
        feature_names=['t2']
    )
]

ebc = EmbeddingBagCollection(eb_configs)
print(ebc(kt)['t1'])
print(ebc(kt2)['t1'])

eb_configs = [
    EmbeddingConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e1',
        feature_names=['t1']
    ),
    EmbeddingConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e2',
        feature_names=['t2']
    )
]

ebc = EmbeddingCollection(eb_configs)

print(ebc(kt)["t1"].lengths().size())
print(ebc(kt2)["t1"].lengths().size())

Result: the output of EmbeddingCollection is not rearranged according to inverse_indices; the lengths sizes are 3 and 1.

colin2328 commented 4 months ago

cc @joshuadeng

joshuadeng commented 4 months ago

hi @yjjinjie, EmbeddingCollection does not currently support variable batch size per feature. This work is being planned, so stay tuned.
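Until that lands, one possible workaround is to re-expand the per-key jagged output manually: index the lengths with the inverse indices, convert the deduplicated lengths to offsets, and gather the value rows segment by segment. A sketch in plain torch (using the 't1' shapes from the issue; the embedding values are hypothetical, and this is not torchrec API):

```python
import torch

# Hypothetical 't1' output for the deduplicated batch of 1 from the issue.
lengths = torch.tensor([1], dtype=torch.int64)  # jagged lengths, dedup batch
values = torch.tensor([[0.5, 0.5]])             # one embedding row per value
inv = torch.tensor([0, 0, 0])                   # inverse indices for 't1'

# Expand the lengths to the full logical batch of 3.
full_lengths = lengths[inv]

# Convert dedup lengths to offsets, then gather each row's value segment.
offsets = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0)])
segments = [values[offsets[i]:offsets[i + 1]] for i in inv.tolist()]
full_values = torch.cat(segments)

print(full_lengths.tolist(), full_values.shape[0])
```

The segment-wise gather (rather than a plain `index_select` on rows) is needed because a jagged output can have a variable number of value rows per batch position.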