@hekun24 Hello, and thank you for your interest in RecBole!
First, the latest version of RecBole slightly changed how parameters such as `user_inter_num_interval` are configured. Judging from your output, you are currently running an older version, so please try updating first.
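For reference, a minimal sketch of the interval-style filtering syntax used by recent RecBole versions; the `[5,inf)` bounds are illustrative values, not a recommendation:

```python
from recbole.quick_start import run_recbole

# Recent RecBole versions express interaction-count filters as interval
# strings; the [5,inf) bounds here are illustrative, not recommended values.
config_dict = {
    "user_inter_num_interval": "[5,inf)",  # keep users with at least 5 interactions
    "item_inter_num_interval": "[5,inf)",  # keep items with at least 5 interactions
}
run_recbole(model="GRU4Rec", dataset="ml-1m", config_dict=config_dict)
```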
Second, your configuration does not filter by interaction count or item count, yet the output for ml-1m reports 3417 items and 999611 interactions. This does not match the statistics of the ml-1m dataset provided by RecBole; in particular, the item count differs:
The number of users: 6041
Average actions of users: 165.5975165562914
The number of items: 3707
Average actions of items: 269.88909875876953
The number of inters: 1000209
The sparsity of the dataset: 95.53358229599758%
Finally, and most importantly, when comparing two models on a given dataset, we cannot draw conclusions from just two experimental runs. Instead, the hyperparameters of each model must be tuned for that dataset, so that every model reaches its best result before the comparison is made.
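As a hedged illustration, RecBole's built-in hyper-parameter search can automate such tuning; the params file name, config file name, and search grid below are placeholders, not values from this issue:

```python
from recbole.quick_start import objective_function
from recbole.trainer import HyperTuning

# Sketch of RecBole's built-in hyper-parameter search. "model.hyper" is a
# placeholder params file containing lines such as:
#   hidden_dropout_prob choice [0.2,0.3,0.5]
#   attn_dropout_prob choice [0.2,0.3,0.5]
hp = HyperTuning(
    objective_function,
    algo="exhaustive",
    params_file="model.hyper",
    fixed_config_file_list=["sasrec.yaml"],  # placeholder config file
)
hp.run()
print("best params:", hp.best_params)
```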
Regarding your observation that GRU4Rec performs far better than SASRec on ml-1m: this is because, under the default hyperparameter settings, SASRec does not reach its best performance on ml-1m. For example, the parameters `hidden_dropout_prob` and `attn_dropout_prob` both default to 0.5, but 0.5 is not the optimal value for SASRec on the ml-1m dataset. According to our team's tests, setting `hidden_dropout_prob = attn_dropout_prob = 0.2` on the dense ml-1m dataset significantly improves SASRec's recommendation results. Under our reasonably fair hyperparameter tuning and testing, SASRec outperforms GRU4Rec.
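A minimal sketch of rerunning SASRec with the 0.2 dropout setting suggested above, assuming the standard `run_recbole` entry point:

```python
from recbole.quick_start import run_recbole

# Rerun SASRec on ml-1m with the dropout values suggested above for the
# dense ml-1m dataset (both parameters default to 0.5).
run_recbole(
    model="SASRec",
    dataset="ml-1m",
    config_dict={"hidden_dropout_prob": 0.2, "attn_dropout_prob": 0.2},
)
```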
Understood, thank you for the patient reply.
Describe the bug
A clear and concise description of what the bug is.
How to reproduce
Steps to reproduce the bug:
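The logs below were presumably produced by runs along these lines (a hedged sketch; the actual launch command and config files are not shown in the issue):

```python
from recbole.quick_start import run_recbole

# Hypothetical launch for the two runs logged below; the config file names
# are placeholders for files holding the parameters shown in the logs.
run_recbole(model="GRU4Rec", dataset="ml-1m", config_file_list=["gru4rec.yaml"])
run_recbole(model="SASRec", dataset="ml-1m", config_file_list=["sasrec.yaml"])
```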
GRU4Rec configuration
Sat 14 May 2022 16:25:09 INFO
General Hyper Parameters: gpu_id = 0 use_gpu = True seed = 2020 state = INFO reproducibility = True data_path = dataset/ml-1m checkpoint_dir = saved show_progress = True save_dataset = False dataset_save_path = None save_dataloaders = False dataloaders_save_path = None log_wandb = False
Training Hyper Parameters: epochs = 300 train_batch_size = 4096 learner = adam learning_rate = 0.001 neg_sampling = None eval_step = 1 stopping_step = 10 clip_grad_norm = None weight_decay = 0.0 loss_decimal_place = 4
Evaluation Hyper Parameters: eval_args = {'split': {'LS': 'valid_and_test'}, 'order': 'TO', 'mode': 'full', 'group_by': 'user'} repeatable = True metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision'] topk = [10] valid_metric = MRR@10 valid_metric_bigger = True eval_batch_size = 4096 metric_decimal_place = 4
Dataset Hyper Parameters: field_separator = "\t" seq_separator = " " USER_ID_FIELD = user_id ITEM_ID_FIELD = item_id RATING_FIELD = rating TIME_FIELD = timestamp seq_len = None LABEL_FIELD = label threshold = None NEG_PREFIX = neg_ load_col = {'inter': ['user_id', 'item_id', 'rating', 'timestamp']} unload_col = None unused_col = None additional_feat_suffix = None rm_dup_inter = None val_interval = None filter_inter_by_user_or_item = True user_inter_num_interval = None item_inter_num_interval = None alias_of_user_id = None alias_of_item_id = None alias_of_entity_id = None alias_of_relation_id = None preload_weight = None normalize_field = None normalize_all = None ITEM_LIST_LENGTH_FIELD = item_length LIST_SUFFIX = _list MAX_ITEM_LIST_LENGTH = 50 POSITION_FIELD = position_id HEAD_ENTITY_ID_FIELD = head_id TAIL_ENTITY_ID_FIELD = tail_id RELATION_ID_FIELD = relation_id ENTITY_ID_FIELD = entity_id benchmark_filename = None
Other Hyper Parameters: wandb_project = recbole require_pow = False embedding_size = 64 hidden_size = 128 num_layers = 1 dropout_prob = 0.3 loss_type = CE MODEL_TYPE = ModelType.SEQUENTIAL MODEL_INPUT_TYPE = InputType.POINTWISE eval_type = EvaluatorType.RANKING device = cuda train_neg_sample_args = {'strategy': 'none'} eval_neg_sample_args = {'strategy': 'full', 'distribution': 'uniform'}
Sat 14 May 2022 16:25:14 INFO ml-1m The number of users: 6041 Average actions of users: 165.49850993377484 The number of items: 3417 Average actions of items: 292.6261709601874 The number of inters: 999611 The sparsity of the dataset: 95.15741545057172% Remain Fields: ['user_id', 'item_id', 'rating', 'timestamp']
SASRec configuration
Sat 14 May 2022 20:19:09 INFO
General Hyper Parameters: gpu_id = 0 use_gpu = True seed = 2020 state = INFO reproducibility = True data_path = dataset/ml-1m checkpoint_dir = saved show_progress = True save_dataset = False dataset_save_path = None save_dataloaders = False dataloaders_save_path = None log_wandb = False
Training Hyper Parameters: epochs = 300 train_batch_size = 4096 learner = adam learning_rate = 0.001 neg_sampling = None eval_step = 1 stopping_step = 10 clip_grad_norm = None weight_decay = 0.0 loss_decimal_place = 4
Evaluation Hyper Parameters: eval_args = {'split': {'LS': 'valid_and_test'}, 'order': 'TO', 'mode': 'full', 'group_by': 'user'} repeatable = True metrics = ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision'] topk = [10] valid_metric = NDCG@10 valid_metric_bigger = True eval_batch_size = 4096 metric_decimal_place = 4
Dataset Hyper Parameters: field_separator = "\t" seq_separator = " " USER_ID_FIELD = user_id ITEM_ID_FIELD = item_id RATING_FIELD = rating TIME_FIELD = timestamp seq_len = None LABEL_FIELD = label threshold = None NEG_PREFIX = neg_ load_col = {'inter': ['user_id', 'item_id', 'rating', 'timestamp']} unload_col = None unused_col = None additional_feat_suffix = None rm_dup_inter = None val_interval = None filter_inter_by_user_or_item = True user_inter_num_interval = None item_inter_num_interval = None alias_of_user_id = None alias_of_item_id = None alias_of_entity_id = None alias_of_relation_id = None preload_weight = None normalize_field = None normalize_all = None ITEM_LIST_LENGTH_FIELD = item_length LIST_SUFFIX = _list MAX_ITEM_LIST_LENGTH = 50 POSITION_FIELD = position_id HEAD_ENTITY_ID_FIELD = head_id TAIL_ENTITY_ID_FIELD = tail_id RELATION_ID_FIELD = relation_id ENTITY_ID_FIELD = entity_id benchmark_filename = None
Other Hyper Parameters: wandb_project = recbole require_pow = False n_layers = 2 n_heads = 2 hidden_size = 64 inner_size = 256 hidden_dropout_prob = 0.5 attn_dropout_prob = 0.5 hidden_act = gelu layer_norm_eps = 1e-12 initializer_range = 0.02 loss_type = CE MODEL_TYPE = ModelType.SEQUENTIAL MODEL_INPUT_TYPE = InputType.POINTWISE eval_type = EvaluatorType.RANKING device = cuda train_neg_sample_args = {'strategy': 'none'} eval_neg_sample_args = {'strategy': 'full', 'distribution': 'uniform'}
Sat 14 May 2022 20:19:14 INFO ml-1m The number of users: 6041 Average actions of users: 165.49850993377484 The number of items: 3417 Average actions of items: 292.6261709601874 The number of inters: 999611 The sparsity of the dataset: 95.15741545057172% Remain Fields: ['user_id', 'item_id', 'rating', 'timestamp']
Expected behavior
In general, SASRec should perform much better than GRU4Rec, but I have tried many times, and whether I split by leave-one-out or by ratio, GRU4Rec always beats SASRec. The same problem occurs on ml-100k.
GRU4Rec results
SASRec results
Screenshots
Experiment environment (please complete the following information):