NVIDIA-Merlin / Transformers4Rec

Transformers4Rec is a flexible and efficient library for sequential and session-based recommendation and works with PyTorch.
https://nvidia-merlin.github.io/Transformers4Rec/main
Apache License 2.0

[BUG] Inconsistent inference and evaluation results of the XLNET-CLM even on the training set! #761

Open dcy0577 opened 8 months ago

dcy0577 commented 8 months ago

Bug description

Hello, I followed the example and successfully trained an XLNet-CLM model on my custom dataset. However, I noticed that while the model performs well on the validation set with trainer.evaluate() (even achieving 90% recall@5), the results I get from trainer.predict() for inference fall far short of the expected performance. So I ran an experiment: I took a portion of the training set and fed it into both functions, using sequence[:] for evaluate() and sequence[:-1] for predict() (a sketch of how I built these two inputs follows the printout below):

=========data for eval===============
   session_id                                      item_id-list
0           1          [26, 26, 26, 26, 4, 4, 4, 4, 4, 4, 4, 4]
1           2       [7, 43, 35, 3, 3, 7, 29, 35, 35, 111, 5, 9]
2           3       [74, 7, 74, 7, 7, 110, 32, 67, 4, 4, 17, 7]
4           5   [56, 25, 25, 25, 25, 25, 25, 23, 34, 4, 19, 43]
5           6  [270, 41, 41, 41, 41, 7, 43, 34, 78, 38, 71, 23]
6           7        [74, 28, 28, 5, 5, 24, 9, 5, 5, 59, 4, 91]
=========data for infer===============
   session_id                                  item_id-list
0           1         [26, 26, 26, 26, 4, 4, 4, 4, 4, 4, 4]
1           2      [7, 43, 35, 3, 3, 7, 29, 35, 35, 111, 5]
2           3      [74, 7, 74, 7, 7, 110, 32, 67, 4, 4, 17]
4           5   [56, 25, 25, 25, 25, 25, 25, 23, 34, 4, 19]
5           6  [270, 41, 41, 41, 41, 7, 43, 34, 78, 38, 71]
6           7        [74, 28, 28, 5, 5, 24, 9, 5, 5, 59, 4]
=========labels:===============
0     4
1     9
2     7
4    43
5    23
6    91
Name: item_id-list, dtype: int64
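
For context, this is roughly how I built the two inputs and the labels above (a minimal sketch of my preprocessing; the read path is only a placeholder for my preprocessed training sessions, and the column names follow the printout):

import pandas as pd

# placeholder path for the preprocessed training sessions (assumed, not the exact file I used)
sessions = pd.read_parquet("data/preproc_sessions_by_day_3827/1/train.parquet")

# evaluate() gets the full sequences; predict() gets the same sequences with the last item removed
eval_df = sessions[["session_id", "item_id-list"]].copy()
infer_df = eval_df.copy()
infer_df["item_id-list"] = infer_df["item_id-list"].apply(lambda seq: list(seq)[:-1])

# the held-out last item of each session is the label the prediction should recover
labels = eval_df["item_id-list"].apply(lambda seq: list(seq)[-1])

eval_df.to_parquet("data/preproc_sessions_by_day_3827/1/train_truncate_for_eval.parquet")
infer_df.to_parquet("data/preproc_sessions_by_day_3827/1/train_truncate_for_infer.parquet")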

Ideally, the predictions from both functions should be similar, but the inference results are significantly worse:

=========inference===============
PredictionOutput(predictions=(array([[ 4, 19,  7,  5, 11],
       [ 5, 15, 23, 22,  7],
       [17,  7, 30, 70, 15],
       [19,  4, 11,  5, 15],
       [ 7,  4, 79, 34,  6],
       [ 4,  7, 19, 11,  5]]), array([[7.677282 , 5.3613596, 5.00848  , 4.6888046, 4.319791 ],
       [7.15173  , 5.7525525, 5.726646 , 4.9717607, 4.903692 ],
       [7.843027 , 7.171323 , 5.413306 , 5.378131 , 5.3367157],
       [6.309387 , 5.833148 , 5.6351004, 4.9878273, 4.4770455],
       [7.6406755, 5.9271154, 5.91128  , 5.8692527, 5.589204 ],
       [7.7423315, 5.4810033, 5.0784097, 4.7432585, 4.603503 ]],
      dtype=float32)), label_ids=None, metrics={'test_runtime': 0.792, 'test_samples_per_second': 7.576, 'test_steps_per_second': 3.788})
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 120.82it/s]
=========evaluation===============
PredictionOutput(predictions=(array([[  4,  12,  19,   7,  36],
       [  9,  12,  10,  33,  49],
       [  7,  34,   4,  30,  79],
       [ 43, 110, 129,   6, 126],
       [ 23,  34,  15,  56,  42],
       [ 91,  59, 102, 121,  83]]), array([[18.594198 , 10.1901245, 10.138629 ,  9.487766 ,  9.470682 ],
       [18.095396 , 11.1290245, 10.906204 , 10.659245 , 10.633735 ],
       [16.824923 ,  9.672426 ,  9.347423 ,  8.993546 ,  8.8174095],
       [15.070214 , 10.285906 ,  9.3903475,  8.776609 ,  8.753696 ],
       [16.2281   , 10.282413 , 10.155615 ,  9.619689 ,  9.48603  ],
       [14.316702 , 10.41172  ,  9.787606 ,  9.7606325,  9.695224 ]],
      dtype=float32)), label_ids=array([ 4,  9,  7, 43, 23, 91]), metrics={'eval_/next-item/ndcg_at_5': 1.0, 'eval_/next-item/ndcg_at_10': 1.0, 'eval_/next-item/recall_at_5': 1.0, 'eval_/next-item/recall_at_10': 1.0, 'eval_/next-item/avg_precision_at_5': 1.0, 'eval_/next-item/avg_precision_at_10': 1.0, 'eval_/loss': 0.02894706465303898, 'eval_runtime': 0.0639, 'eval_samples_per_second': 93.926, 'eval_steps_per_second': 46.963})

Please note that I conducted this experiment on the training set. The outputs from evaluation are as expected, but the inference outputs are not. I'm curious to know why this is happening. Thanks!

Here is my code that generates the outputs above:

  import torch
  import transformers4rec.torch as tr

  # tr_model, schema, and evaluate_change_output are defined earlier (not shown here)
  tr_model.load_state_dict(torch.load("tmp/checkpoint-450/pytorch_model.bin"))
  tr_model.eval()

  args = tr.trainer.T4RecTrainingArguments(
          output_dir="tmp",
          per_device_eval_batch_size=2,
          max_sequence_length=30,
          fp16=True,
      )

  trainer = tr.Trainer(
      model=tr_model,
      args=args,
      schema=schema,
      compute_metrics=True,
      )

  trainer.test_dataset_or_path = 'data/preproc_sessions_by_day_3827/1/train_truncate_for_infer.parquet'
  trainer.eval_dataset_or_path = 'data/preproc_sessions_by_day_3827/1/train_truncate_for_eval.parquet'
  trainer.args.predict_top_k = 5

  prediction = trainer.predict(trainer.test_dataset_or_path)
  print("=========inference===============")
  print(prediction)

  # a small monkey patch to output the predictions, not only the metrics
  tr.Trainer.evaluate = evaluate_change_output
  prediction = trainer.evaluate()
  print("=========evaluation===============")
  print(prediction)
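
To sanity-check the two outputs against the same labels outside of the Trainer metrics, I use a small numpy helper on the PredictionOutput arrays (my own helper, not part of Transformers4Rec):

import numpy as np

def recall_at_k(topk_item_ids: np.ndarray, labels: np.ndarray, k: int = 5) -> float:
    # fraction of rows whose true next item appears among the top-k predicted item ids
    hits = (topk_item_ids[:, :k] == labels.reshape(-1, 1)).any(axis=1)
    return float(hits.mean())

# prediction.predictions is a tuple of (top-k item ids, top-k scores), so for example:
# recall_at_k(prediction.predictions[0], np.array([4, 9, 7, 43, 23, 91]), k=5)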

Environment details

dcy0577 commented 8 months ago

I've reviewed the CLM masking code, and I'm a little confused about this line: https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/625897c3f55135342dbe53c32b7650a8ecb86c75/transformers4rec/torch/masking.py#L319 What is the purpose of removing the last position when the input is padded? In that case the dropped position is just a padding embedding; shouldn't the feature of the last actual item be removed instead?
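
A toy illustration of what I mean (my own example, not library code): with a right-padded sequence, inputs[:, :-1] only drops a padding position, so the embedding of the last real item is still fed to the model:

import torch

# batch of 1, sequence length 6, embedding dim 1; values 0..5 stand in for the item embeddings
inputs = torch.arange(6.0).view(1, 6, 1)
# the first four positions are real interactions, the last two are padding
mask_schema = torch.tensor([[True, True, True, True, False, False]])

shifted = inputs[:, :-1]      # drops position 5, which is padding
print(shifted.squeeze(-1))    # tensor([[0., 1., 2., 3., 4.]]) -> the last real item (3) is still there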

dcy0577 commented 8 months ago

Anyway, I changed some code in apply_mask_to_inputs of the CLM masking:

def apply_mask_to_inputs_CLM(
        self,
        inputs: torch.Tensor,
        mask_schema: torch.Tensor,
        training: bool = False,
        testing: bool = False,
    ) -> torch.Tensor:
        if not training and not testing:
            # Replacing the inputs corresponding to padded items with a trainable embedding
            # To mimic training and evaluation masking strategy
            inputs = torch.where(
                mask_schema.unsqueeze(-1).bool(),
                inputs,
                self.masked_item_embedding.to(inputs.dtype),
            )
            return inputs

        # # shift sequence of interaction embeddings
        # pos_emb_inp = inputs[:, :-1]
        # # Adding a masked item in the sequence to return to the initial sequence.
        # pos_emb_inp = torch.cat(  # type: ignore
        #     [
        #         pos_emb_inp,
        #         torch.zeros(
        #             (pos_emb_inp.shape[0], 1, pos_emb_inp.shape[2]),
        #             dtype=pos_emb_inp.dtype,
        #         ).to(inputs.device),
        #     ],
        #     axis=1,
        # )

        pos_emb_inp = inputs
        pos_emb_inp_new = pos_emb_inp.clone()
        # Iterate over each row in the boolean tensor
        for i in range(mask_schema.shape[0]):
            # Find the index of the last True value in the row
            # If there's no True value, idx will be -1
            idx = (mask_schema[i].nonzero(as_tuple=True)[0]).max().item() if mask_schema[i].any() else -1
            # Replace the embedding at that position with a zero vector
            if idx != -1:
                pos_emb_inp_new[i, idx] = torch.zeros(pos_emb_inp.shape[2], dtype=pos_emb_inp.dtype).to(inputs.device)

        pos_emb_inp = pos_emb_inp_new
        # Replacing the inputs corresponding to padded items with a trainable embedding
        pos_emb_inp = torch.where(
            mask_schema.unsqueeze(-1).bool(),
            pos_emb_inp,
            self.masked_item_embedding.to(pos_emb_inp.dtype),
        )
        return pos_emb_inp
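
For reference, here is a vectorized sketch of the per-row loop above (my own rewrite; it should do the same thing as the loop, but I have only checked it on toy shapes):

import torch

def zero_last_unmasked_position(pos_emb_inp: torch.Tensor, mask_schema: torch.Tensor) -> torch.Tensor:
    # zero the embedding at the last non-padded position of each sequence, without a Python loop
    batch_size, seq_len = mask_schema.shape
    positions = torch.arange(seq_len, device=pos_emb_inp.device).expand(batch_size, seq_len)
    # index of the last True per row; rows without any True get -1 and are left untouched
    last_idx = torch.where(mask_schema.bool(), positions, torch.full_like(positions, -1)).max(dim=1).values
    out = pos_emb_inp.clone()
    valid = last_idx >= 0
    rows = torch.arange(batch_size, device=pos_emb_inp.device)[valid]
    out[rows, last_idx[valid]] = 0.0
    return out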

Interestingly, with this modification, the metrics of XLNet in the CLM setting have decreased compared to before, which seems more reasonable. I've also noticed that #719 and #746 mention a similar issue. Additionally, I observed that the outputs of the predict and evaluate functions have become similar:

# same inputs as before!
=========inference===============
PredictionOutput(predictions=(array([[ 4,  3,  7,  5,  6],
       [ 5,  7, 30,  3, 22],
       [17,  7,  6, 26, 11],
       [ 3,  4, 25,  7, 18],
       [71, 26, 24,  4,  3],
       [ 4,  5,  3,  7, 22]]), array([[ 8.456369 ,  5.8312187,  5.6498675,  5.1875997,  4.956415 ],
       [10.017425 ,  6.9653053,  6.9261403,  6.617892 ,  6.447962 ],
       [ 9.492199 ,  8.0066   ,  6.5381365,  6.097307 ,  5.9652367],
       [ 6.56517  ,  6.2221594,  5.269926 ,  5.1904283,  5.092497 ],
       [ 6.0176606,  5.6072693,  5.431714 ,  5.273334 ,  5.039026 ],
       [ 8.518238 ,  6.537841 ,  5.880717 ,  5.6688414,  5.0900187]],
      dtype=float32)), label_ids=None, metrics={'test_runtime': 2.0936, 'test_samples_per_second': 2.866, 'test_steps_per_second': 1.433})
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.08it/s]
=========evaluation===============
PredictionOutput(predictions=(array([[ 4,  7,  3,  6, 30],
       [ 5,  7, 30,  3,  4],
       [17,  7,  6, 26,  4],
       [ 3,  4, 25, 39,  7],
       [ 4, 24, 71, 26, 14],
       [ 4, 30,  5,  3,  7]]), array([[7.4311004, 4.9409456, 4.876472 , 4.848051 , 4.7532825],
       [7.4449983, 6.3823347, 6.3242016, 6.160487 , 5.6924496],
       [7.9830294, 7.4659915, 6.0589595, 5.7209487, 5.57537  ],
       [6.721603 , 5.4623756, 5.102635 , 4.9805446, 4.9469085],
       [4.9136906, 4.843173 , 4.796401 , 4.7077065, 4.5187464],
       [7.7905493, 5.1441355, 5.0995426, 5.0602183, 5.031133 ]],
      dtype=float32)), label_ids=array([ 4,  9,  7, 43, 23, 91]), metrics={'eval_/next-item/ndcg_at_5': 0.27182161808013916, 'eval_/next-item/ndcg_at_10': 0.27182161808013916, 'eval_/next-item/recall_at_5': 0.3333333432674408, 'eval_/next-item/recall_at_10': 0.3333333432674408, 'eval_/next-item/avg_precision_at_5': 0.25, 'eval_/next-item/avg_precision_at_10': 0.25, 'eval_/loss': 4.302186489105225, 'eval_runtime': 0.6978, 'eval_samples_per_second': 8.599, 'eval_steps_per_second': 4.299})

I'm not sure if my changes are correct, and I strongly recommend that you pay attention to this issue. Thanks!