pytorch / torchrec

Pytorch domain library for recommendation systems
https://pytorch.org/torchrec/
BSD 3-Clause "New" or "Revised" License

How to share embeddings of some features between two EmbeddingBagCollections? #1916

Open tiankongdeguiji opened 5 months ago

tiankongdeguiji commented 5 months ago

In a two-tower retrieval model, negative items usually have to be sampled randomly, so the item tower's batch size is typically larger than the user tower's. As a result, a single EmbeddingBagCollection is inadequate for this setup. When two separate EmbeddingBagCollections are used instead, how can the embeddings of some features be shared between them?
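For reference, outside torchrec the sharing itself is straightforward. In plain PyTorch a table is shared by reusing one module, and each call can use its own batch size. The sketch below uses `torch.nn.EmbeddingBag` directly rather than an `EmbeddingBagCollection`, so it says nothing about sharding. The ids and sizes mirror the example later in this thread, and the dot-product scoring is only an illustrative stand-in for a real two-tower loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical shared table: both towers call the same module (sum pooling).
shared = nn.EmbeddingBag(num_embeddings=100, embedding_dim=16, mode="sum")

# User tower: 2 bags with one id each; offsets mark where each bag starts.
user_ids, user_offsets = torch.tensor([0, 1]), torch.tensor([0, 1])
# Item tower: 4 bags (positives plus sampled negatives), one id each.
item_ids, item_offsets = torch.tensor([2, 3, 4, 5]), torch.tensor([0, 1, 2, 3])

user_emb = shared(user_ids, user_offsets)  # shape [2, 16]
item_emb = shared(item_ids, item_offsets)  # shape [4, 16]

# Toy two-tower scores: every user against every item, shape [2, 4].
scores = user_emb @ item_emb.T
scores.sum().backward()

# Gradients from both towers accumulate into the one shared weight.
rows_with_grad = shared.weight.grad.abs().sum(dim=1).nonzero().flatten().tolist()
print(user_emb.shape, item_emb.shape, scores.shape, rows_with_grad)
```

The question in this thread is how to get the same effect with torchrec's collections, where the input KeyedJaggedTensor normally carries one batch size for all keys.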

tiankongdeguiji commented 5 months ago

Hi, @henrylhtsang @IvanKobzarev @joshuadeng @PaulZhang12, could you take a look at this problem?

henrylhtsang commented 5 months ago

  • [not recommended] use padding so they have the same length
  • use VBE in KJT

tiankongdeguiji commented 5 months ago

@henrylhtsang If we use VBE, the outputs of the user tower and the item tower will also be padded to the same batch size. Is this approach efficient?

henrylhtsang commented 5 months ago

@tiankongdeguiji oh I misspoke. I meant to say you either need to pad the inputs, or use VBE. iirc you shouldn't need to pad the outputs of VBE, but admittedly I am not familiar with that part
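The "pad the inputs" option can be sketched in plain PyTorch. This assumes padding means appending zero-length bags to `user_f` until it has the item batch size, so both keys share one stride, which is what a non-VBE KeyedJaggedTensor would require. A single `torch.nn.EmbeddingBag` stands in for the shared table, and the layout (all `user_f` bags, then all `item_f` bags) is an assumption meant to resemble the key-major KJT layout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical single table shared by user_f and item_f (sum pooling).
shared = nn.EmbeddingBag(num_embeddings=100, embedding_dim=16, mode="sum")

user_values, user_lengths = torch.tensor([0, 1]), torch.tensor([1, 1])
item_values, item_lengths = torch.tensor([2, 3, 4, 5]), torch.tensor([1, 1, 1, 1])

# Pad user_f up to the item batch size with zero-length bags so both keys
# have the same stride (here 4).
stride = item_lengths.numel()
pad = torch.zeros(stride - user_lengths.numel(), dtype=user_lengths.dtype)
lengths = torch.cat([user_lengths, pad, item_lengths])  # [1, 1, 0, 0, 1, 1, 1, 1]
values = torch.cat([user_values, item_values])

# Offsets are the exclusive cumulative sum of the bag lengths.
offsets = torch.cat([torch.zeros(1, dtype=lengths.dtype), lengths.cumsum(0)[:-1]])
pooled = shared(values, offsets)  # [8, 16]: 4 user rows, then 4 item rows

user_emb = pooled[: user_lengths.numel()]             # real user rows, [2, 16]
item_emb = pooled[stride:]                            # item rows, [4, 16]
padded_rows = pooled[user_lengths.numel() : stride]   # empty bags pool to zeros
```

The padded bags cost no lookups, but they do produce output rows that must be sliced away. This is one reason padding is listed as the less attractive option.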

tiankongdeguiji commented 5 months ago


@henrylhtsang For example, suppose the batch size of the user tower is 2 and the batch size of the item tower is 4. If the user tower and the item tower do not need to share embeddings, we can create two separate collections:

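import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# User tower input: batch size 2, one id per sample.
kt_u = KeyedJaggedTensor(
    keys=['user_f'],
    values=torch.tensor([0,1]),
    lengths=torch.tensor([1,1], dtype=torch.int64),
)
# Item tower input: batch size 4 (positives plus sampled negatives).
kt_i = KeyedJaggedTensor(
    keys=['item_f'],
    values=torch.tensor([2,3,4,5]),
    lengths=torch.tensor([1,1,1,1], dtype=torch.int64),
)
# Two independent tables (e1 and e2), so nothing is shared.
eb_config_u = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e1',
        feature_names=['user_f']
    )
]
eb_config_i = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e2',
        feature_names=['item_f']
    )
]
ebc_u = EmbeddingBagCollection(eb_config_u)
ebc_i = EmbeddingBagCollection(eb_config_i)
print('user:', ebc_u(kt_u).values().shape)
print('item:', ebc_i(kt_i).values().shape)
# user: torch.Size([2, 16])
# item: torch.Size([4, 16])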

If we use VBE to implement shared embeddings, the outputs of the user tower and the item tower are padded to the same batch size:

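import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# One KJT with a variable batch size per key (VBE):
# user_f has stride 2, item_f has stride 4.
kt = KeyedJaggedTensor(
    keys=['user_f', 'item_f'],
    values=torch.tensor([0,1,2,3,4,5]),
    lengths=torch.tensor([1,1,1,1,1,1], dtype=torch.int64),
    stride_per_key_per_rank=[[2], [4]],
    # Per-key inverse indices, each padded to the largest batch size (4).
    inverse_indices=(['user_f', 'item_f'], torch.tensor([[0,1,1,1], [0,1,2,3]]))
)
# A single table e1 shared by user_f and item_f.
eb_configs = [
    EmbeddingBagConfig(
        num_embeddings=100,
        embedding_dim=16,
        name='e1',
        feature_names=['user_f', 'item_f']
    )
]
ebc = EmbeddingBagCollection(eb_configs)
print('user+item:', ebc(kt).values().shape)
# user+item: torch.Size([4, 32])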

The batch size of the user tower is 4 rather than 2.
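The output above looks consistent with the unsharded collection pooling each key at its own batch size and then gathering rows with `inverse_indices`. Under that assumption, the extra user rows are copies rather than extra lookups. The sketch below mimics that expansion in plain PyTorch, with random tensors standing in for the pooled embeddings. It also shows one hypothetical way to recover the 2 real user rows: take the first output position that maps to each user sample.

```python
import torch

torch.manual_seed(0)
dim = 16
# Stand-ins for per-key pooled outputs before expansion.
user_pooled = torch.randn(2, dim)  # user_f at stride 2
item_pooled = torch.randn(4, dim)  # item_f at stride 4

# inverse_indices from the KJT above: one row per key, length = max stride (4).
inverse = torch.tensor([[0, 1, 1, 1], [0, 1, 2, 3]])

# Expand each key to the max stride by gathering rows, then concatenate per key.
expanded = torch.cat(
    [user_pooled.index_select(0, inverse[0]), item_pooled.index_select(0, inverse[1])],
    dim=1,
)  # [4, 32]; user rows 1-3 are all copies of user sample 1

# Assumes every user sample appears at least once in inverse[0].
n_user = int(inverse[0].max()) + 1
# First output position for each user sample (min position, starting from 4).
first_pos = torch.full((n_user,), inverse.size(1), dtype=torch.long)
first_pos.scatter_reduce_(0, inverse[0], torch.arange(inverse.size(1)), reduce="amin")
recovered_user = expanded[first_pos, :dim]  # [2, 16]
```

If this matches torchrec's actual behavior, the padding overhead is mainly the memory for the copied rows.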

henrylhtsang commented 4 months ago

Not an expert, but can you try the sharded version of EBC? I'm not sure whether the unsharded EBC supports VBE very well.