pytorch / torchrec

PyTorch domain library for recommendation systems
https://pytorch.org/torchrec/
BSD 3-Clause "New" or "Revised" License

quantize_embeddings + KeyedJaggedTensor + VBE cannot work #1894

Open yjjinjie opened 5 months ago

yjjinjie commented 5 months ago

import torch
from torchrec import KeyedJaggedTensor
from torchrec import EmbeddingBagConfig, EmbeddingConfig
from torchrec import EmbeddingBagCollection, EmbeddingCollection

kt2 = KeyedJaggedTensor(
    keys=['user_id', 'item_id', 'id_3', 'id_4', 'id_5', 'raw_1', 'raw_4', 'combo_1', 'lookup_2', 'lookup_3', 'lookup_4', 'match_2', 'match_3', 'match_4', 'click_50_seq__item_id', 'click_50_seq__id_3', 'click_50_seq__raw_1'], 
    values=torch.tensor([573174,   5073,   3562,      3,     18,     13,     11,     49,     26,
             4,      2,      2,      4,      2,      4, 736847, 849333, 997432,
        640218,   9926,   9926,      0,      0,      0,      0,  59926,  59926,
             0,      0,      0,      0,   2835,    769,   1265,   8232,   6399,
           114,   7487,   2876,    953,   7840,   7538,   7998,   7852,   3528,
          1475,   7620,   6110,    572,    735,   4405,   5655,   6736,   2173,
          3421,   2311,   7122,   2159,   4535,   2162,   4657,   3151,   4522,
          1075,    306,   8968,   2056,   2256,   3919,   8624,   5372,   6018,
          3861,   4114,   3984,   2287,   1481,   4757,   1189,   2518,    913,
          9421,   3093,   5911,   9704,   8168,   9410,    728,   2451,    243,
          5187,   5836,   8830,   4894,    614,   7705,   9258,   3518,   4434,
             4,      2,      4,      2,      4,      2,      3,      2,      2,
             3,      3,      3,      4,      4,      3,      0,      4,      0,
             2,      2,      3,      4,      4,      0,      2,      2,      4,
             0,      3,      2,      2,      3,      0,      4,      0,      4,
             4,      4,      2,      2,      3,      4,      2,      4,      3,
             4,      2,      4,      2,      2,      2,      2,      0,      3,
             4,      4,      3,      2,      4,      4,      4,      4,      3,
             2,      3,      4,      2,      4,      0,      4,      4,      4,
             4,      0,      0,      2,      1,      1,      0,      3,      4,
             4,      2,      4,      1,      1,      4,      2,      2,      4,
             0,      4,      4,      4,      4,      4,      1,      4,      2,
             0,      0,      0,      2,      4,      4,      2,      4,      2,
             4,      4,      1,      1,      4,      1,      4,      4,      1,
             0,      4,      4,      4,      3,      0,      0,      2,      4,
             2,      2,      4,      4,      4,      2,      2,      4,      2,
             3]),
    lengths=torch.tensor([ 1,  1,  1,  1,  0,  0,  1,  2,  2,  1,  1,  4,  2,  2,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1, 24, 44, 24, 44, 24, 44], dtype=torch.int64),
    stride_per_key_per_rank=[[1], [2], [2], [2], [2], [2], [1], [2], [2], [2], [2], [2], [2], [2], [2], [2], [2]],
    inverse_indices=(['user_id', 'item_id', 'id_3', 'id_4', 'id_5', 'raw_1', 'raw_4', 'combo_1', 'lookup_2', 'lookup_3', 
                      'lookup_4', 'match_2', 'match_3', 'match_4', 'click_50_seq__item_id', 'click_50_seq__id_3', 
                      'click_50_seq__raw_1'], 
                     torch.tensor([[0, 0], [0, 1],[0, 1], [0, 1], [0, 1], [0, 1],[0, 0], [0, 1], [0, 1], [0, 1],
                                   [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]])
    )
)
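
# For context: stride_per_key_per_rank gives each key its own batch size, and
# inverse_indices map the deduplicated rows back to the dense batch of 2, so this
# is a variable-batch-size (VBE) KJT. A quick check, sketched with the public KJT
# accessors variable_stride_per_key() and stride_per_key():
print(kt2.variable_stride_per_key())  # True: this KJT carries per-key batch sizes
print(kt2.stride_per_key())  # [1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]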

eb_configs2 = [
    EmbeddingBagConfig(num_embeddings=1000000, embedding_dim=16, name='user_id_emb', feature_names=['user_id'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=10000, embedding_dim=16, name='item_id_emb', feature_names=['item_id'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=5, embedding_dim=8, name='id_3_emb', feature_names=['id_3'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=100, embedding_dim=16, name='id_4_emb', feature_names=['id_4', 'id_5'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=5, embedding_dim=16, name='raw_1_emb', feature_names=['raw_1'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=5, embedding_dim=16, name='raw_4_emb', feature_names=['raw_4'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=1000000, embedding_dim=16, name='combo_1_emb', feature_names=['combo_1'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=10000, embedding_dim=8, name='lookup_2_emb', feature_names=['lookup_2'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=1000, embedding_dim=8, name='lookup_3_emb', feature_names=['lookup_3'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=5, embedding_dim=16, name='lookup_4_emb', feature_names=['lookup_4'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=100000, embedding_dim=16, name='match_2_emb', feature_names=['match_2'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=10000, embedding_dim=8, name='match_3_emb', feature_names=['match_3'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
    EmbeddingBagConfig(num_embeddings=5, embedding_dim=16, name='match_4_emb', feature_names=['match_4'], weight_init_max=None, weight_init_min=None, pruning_indices_remapping=None, need_pos=False),
]
ebc = EmbeddingBagCollection(eb_configs2)

print(ebc(kt2))
from torchrec.inference.modules import quantize_embeddings

import torch.nn as nn

class EmbeddingGroupImpl(nn.Module):
    # Thin wrapper around the EmbeddingBagCollection, mirroring how the
    # EBC is embedded inside a larger model before quantization.
    def __init__(self, ebc):
        super().__init__()
        self.ebc = ebc

    def forward(self, sparse_feature):
        # Return the pooled embeddings (the original repro dropped the return value).
        return self.ebc(sparse_feature)

a = EmbeddingGroupImpl(ebc=ebc)
a(kt2)

quant_model = quantize_embeddings(a, dtype=torch.qint8, inplace=True)
print(quant_model(kt2))

Error:

Traceback (most recent call last):
  File "/larec/tzrec/tests/test_per2.py", line 89, in <module>
    print(quant_model(kt2))
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/larec/tzrec/tests/test_per2.py", line 83, in forward
    self.ebc(sparse_feature)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torchrec/quant/embedding_modules.py", line 487, in forward
    else emb_op.forward(
  File "/opt/conda/lib/python3.10/site-packages/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py", line 764, in forward
    torch.ops.fbgemm.bounds_check_indices(
  File "/opt/conda/lib/python3.10/site-packages/torch/_ops.py", line 758, in __call__
    return self._op(*args, **(kwargs or {}))
RuntimeError: offsets size 27 is not equal to B (1) * T (14) + 1
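
Reading the numbers in the error against the repro: the quantized TBE's bounds check expects offsets of length B * T + 1 = 1 * 14 + 1 = 15, since the 13 tables cover T = 14 features (id_4_emb serves both id_4 and id_5) and it assumes a single batch size B = 1. The VBE KJT instead supplies per-feature batch sizes [1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2], which sum to 26 and therefore produce offsets of length 27, exactly the mismatch reported.
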
yjjinjie commented 5 months ago

@henrylhtsang could you please take a look at this problem?

PaulZhang12 commented 4 months ago

I don't believe VBE + quantized EBC is supported yet. Quantized EBC uses a completely different FBGEMM TBE than the standard EBC used for training.
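
A possible interim workaround, sketched below: expand the VBE KJT back to a uniform-batch KJT with its inverse indices before calling the quantized EBC. The helper expand_vbe_kjt is hypothetical (not a TorchRec API) and assumes the inverse-indices keys are ordered like kjt.keys():

import torch
from torchrec import KeyedJaggedTensor

def expand_vbe_kjt(kjt: KeyedJaggedTensor) -> KeyedJaggedTensor:
    # Replay each key's deduplicated jagged rows through the inverse
    # indices so every key ends up with the same (dense) batch size.
    _, inv = kjt.inverse_indices()  # inv: [num_keys, dense_batch_size]
    lengths = kjt.lengths()
    offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])
    values = kjt.values()
    new_lengths, new_values = [], []
    row = 0  # running jagged-row index across keys
    for k, stride in enumerate(kjt.stride_per_key()):
        for b in range(inv.size(1)):
            r = row + int(inv[k, b])
            start, end = int(offsets[r]), int(offsets[r + 1])
            new_lengths.append(end - start)
            new_values.append(values[start:end])
        row += stride
    return KeyedJaggedTensor(
        keys=kjt.keys(),
        values=torch.cat(new_values),
        lengths=torch.tensor(new_lengths, dtype=torch.int64),
    )

print(quant_model(expand_vbe_kjt(kt2)))  # untested sketch of the workaround

This gives up the deduplication benefit of VBE, but should satisfy the uniform-batch assumption of the quantized TBE.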

yjjinjie commented 4 months ago

Can you add support for VBE + quantized EBC for inference?