RUCAIBox / RecBole

A unified, comprehensive and efficient recommendation library
https://recbole.io/
MIT License

DSSM model throws an error #1583

Open emilyjeng opened 1 year ago

emilyjeng commented 1 year ago

Hello, when I run the DSSM model on yelp2022, I get the following error:

Traceback (most recent call last):
  File "run_recbole.py", line 48, in <module>
    run_recbole(
  File "/Emily/RecBole-master/recbole/quick_start/quick_start.py", line 81, in run_recbole
    flops = get_flops(model, dataset, config["device"], logger, transform)
  File "/Emily/RecBole-master/recbole/utils/utils.py", line 345, in get_flops
    wrapper(inputs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Emily/RecBole-master/recbole/utils/utils.py", line 286, in forward
    return self.model.predict(interaction)
  File "/Emily/RecBole-master/recbole/model/context_aware_recommender/dssm.py", line 111, in predict
    return self.sigmoid(self.forward(interaction))
  File "/Emily/RecBole-master/recbole/model/context_aware_recommender/dssm.py", line 100, in forward
    user_dnn_out = self.user_mlp_layers(embed_user.view(batch_size, -1))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Emily/RecBole-master/recbole/model/layers.py", line 89, in forward
    return self.mlp_layers(input_feature)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1129, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1129, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x20 and 30x256)
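How to read the last line: `F.linear(input, weight)` multiplies the input (batch x in_features) by the transposed weight (in_features x out_features), so "1x20 and 30x256" means the first layer of the user tower expects 30 input columns but received 20. The pure-Python sketch below mimics only that shape check; `linear_shapes` is a hypothetical helper (not a PyTorch function), and the `nn.Linear(30, 256)` layer is inferred from the error message rather than read from RecBole's code.

```python
# Hypothetical helper, not part of PyTorch or RecBole: reproduces the shape rule
# behind F.linear(input, weight), which computes input @ weight.T.
# mat1 = input (batch x in_cols), mat2 = weight.T (in_features x out_features).

def linear_shapes(input_shape, weight_shape):
    """Return F.linear's output shape, or raise the same message PyTorch uses."""
    batch, in_cols = input_shape
    out_features, in_features = weight_shape   # nn.Linear stores weight as (out, in)
    mat2 = (in_features, out_features)         # shape of weight.T
    if in_cols != mat2[0]:
        raise RuntimeError(
            f"mat1 and mat2 shapes cannot be multiplied "
            f"({batch}x{in_cols} and {mat2[0]}x{mat2[1]})"
        )
    return (batch, out_features)

# Assumed first user-tower layer: nn.Linear(30, 256), i.e. weight shape (256, 30).
print(linear_shapes((1, 30), (256, 30)))  # (1, 256)
try:
    linear_shapes((1, 20), (256, 30))
except RuntimeError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (1x20 and 30x256)
```

The batch size of 1 comes from the traceback itself: the failing call happens inside `get_flops`, not during training.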


My configuration:

context-aware

field_separator: "\t"
seq_separator: " "
USER_ID_FIELD: user_id
ITEM_ID_FIELD: item_id
RATING_FIELD: rating
TIME_FIELD: timestamp
NEGPREFIX: neg
LABEL_FIELD: label
normalize_all: True  # normalization
threshold:
  rating: 4
load_col:
  inter: ['user_id', 'item_id', 'rating', 'timestamp']
  user: ['user_id', 'user_name', 'user_useful']  #, 'user_review_count', 'timestamp', categories
  item: ['item_id', 'item_name', 'city', 'item_stars', 'categories']

user_inter_num_interval: "[13,inf)"
item_inter_num_interval: "[10,inf)"
val_interval:
  rating: "[4,inf)"
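The filtering and labelling settings above can be sketched roughly as follows. This is a simplified illustration under assumptions, not RecBole's implementation: it assumes interval strings like `"[13,inf)"` use the usual closed/open bracket meaning, that `val_interval` is applied before the interaction-count filters, and that `threshold: {rating: 4}` labels an interaction positive when `rating >= 4`. It also does a single counting pass, whereas RecBole may repeat the count filter until it stabilises. `parse_interval` and `filter_and_label` are hypothetical names.

```python
from collections import Counter

def parse_interval(spec):
    """Turn an interval string such as "[13,inf)" into a membership test."""
    spec = spec.strip()
    left_closed, right_closed = spec[0] == "[", spec[-1] == "]"
    low_text, high_text = spec[1:-1].split(",")
    low, high = float(low_text), float(high_text)  # float("inf") handles open ends

    def contains(x):
        above = x >= low if left_closed else x > low
        below = x <= high if right_closed else x < high
        return above and below

    return contains

def filter_and_label(rows, rating_spec="[4,inf)", user_spec="[13,inf)",
                     item_spec="[10,inf)", threshold=4):
    # 1. val_interval: keep interactions whose rating lies in rating_spec.
    in_rating = parse_interval(rating_spec)
    rows = [r for r in rows if in_rating(r["rating"])]
    # 2. user/item_inter_num_interval: one pass over interaction counts.
    user_counts = Counter(r["user_id"] for r in rows)
    item_counts = Counter(r["item_id"] for r in rows)
    user_ok, item_ok = parse_interval(user_spec), parse_interval(item_spec)
    rows = [r for r in rows
            if user_ok(user_counts[r["user_id"]]) and item_ok(item_counts[r["item_id"]])]
    # 3. threshold: label = 1 when rating >= threshold, else 0 (assumed semantics).
    return [dict(r, label=int(r["rating"] >= threshold)) for r in rows]

# Made-up toy data; small count intervals so the example stays readable.
toy = [
    {"user_id": "u1", "item_id": "i1", "rating": 5},
    {"user_id": "u1", "item_id": "i2", "rating": 4},
    {"user_id": "u2", "item_id": "i1", "rating": 3},
    {"user_id": "u3", "item_id": "i2", "rating": 5},
]
kept = filter_and_label(toy, user_spec="[2,inf)", item_spec="[1,inf)")
print([(r["user_id"], r["item_id"], r["label"]) for r in kept])
# [('u1', 'i1', 1), ('u1', 'i2', 1)]
```

Under these assumptions, the rating filter `[4,inf)` combined with `threshold: rating: 4` means every interaction that survives filtering gets label 1.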

model config

embedding_size: 10                # (int) The embedding size of features.
mlp_hidden_size: [256, 256, 256]  # (list of int) The hidden size of MLP layers.
dropout_prob: 0.3                 # (float) The dropout rate of edge in the linear predict layer.
double_tower: True                # (bool) Whether or not to use the double-tower mode.
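With `embedding_size: 10`, the numbers in the error line up with the user-side `load_col` entry if one assumes that each loaded user field contributes one 10-dimensional embedding and the user tower's first `Linear` layer takes their concatenation. That assumption is mine, not taken from RecBole's source. The arithmetic check below uses hypothetical helper names:

```python
# Hypothetical arithmetic check, not RecBole code.
# Assumption: each user-side field contributes one embedding_size-wide vector,
# and the user tower's first Linear layer takes their concatenation.
embedding_size = 10
mlp_hidden_size = [256, 256, 256]
user_fields = ["user_id", "user_name", "user_useful"]  # load_col["user"] above

def tower_input_width(fields, emb_size):
    """Input width of a tower fed with len(fields) concatenated embeddings."""
    return len(fields) * emb_size

def fields_in(width, emb_size):
    """How many emb_size-wide feature vectors a flattened input of `width` holds."""
    if width % emb_size:
        raise ValueError(f"{width} is not a multiple of {emb_size}")
    return width // emb_size

expected = tower_input_width(user_fields, embedding_size)
print(expected, mlp_hidden_size[0])   # 30 256  (matches the 30x256 in the error)
print(fields_in(20, embedding_size))  # 2       (width of the input actually received)
```

Under this assumption, the tower was built for three user features, but the input passed in during `get_flops` carried only two of them. The sketch does not tell you which field was missing.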

epochs: 500
train_batch_size: 4096
eval_batch_size: 4096
metrics: ['Recall', 'MRR', 'NDCG', 'Hit', 'Precision']
valid_metric: Hit@10
loss_type: 'BPR'  # (str) The type of loss function. Range in ['BPR'].

eval_args:
  split: {'RS': [0.8, 0.1, 0.1]}
  group_by: user  # whether to put all of a user's records into one group; must be True when eval_setting uses RO_RS
  order: TO
  mode: full
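Roughly, `split: {'RS': [0.8, 0.1, 0.1]}` with `group_by: user` and `order: TO` means: for each user, sort that user's interactions by time, then cut them 80/10/10 into train, validation and test. The sketch below shows that idea. It is a simplified illustration, and RecBole's own rounding of the split sizes may differ from the `int()` truncation used here. `split_by_user_time` is a hypothetical name.

```python
from collections import defaultdict

def split_by_user_time(rows, ratios=(0.8, 0.1, 0.1)):
    """Per-user chronological split (simplified; RecBole's rounding may differ)."""
    by_user = defaultdict(list)
    for r in rows:                                    # group_by: user
        by_user[r["user_id"]].append(r)
    train, valid, test = [], [], []
    for user_rows in by_user.values():
        user_rows.sort(key=lambda r: r["timestamp"])  # order: TO (oldest first)
        n = len(user_rows)
        n_valid, n_test = int(ratios[1] * n), int(ratios[2] * n)
        n_train = n - n_valid - n_test                # remainder goes to training
        train += user_rows[:n_train]
        valid += user_rows[n_train:n_train + n_valid]
        test += user_rows[n_train + n_valid:]
    return train, valid, test

# Made-up toy data: one user, ten interactions given in reverse time order.
toy = [{"user_id": "u1", "item_id": f"i{t}", "timestamp": t} for t in reversed(range(10))]
train, valid, test = split_by_user_time(toy)
print(len(train), len(valid), len(test))          # 8 1 1
print(valid[0]["timestamp"], test[0]["timestamp"])  # 8 9
```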

disable negative sampling

train_neg_sample_args: {'distribution': 'uniform', 'sample_num': 1} #BPR
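`train_neg_sample_args: {'distribution': 'uniform', 'sample_num': 1}` together with `loss_type: 'BPR'` means, roughly, that each observed (user, item) pair gets one item drawn uniformly from items the user has not interacted with, and the BPR loss pushes the positive score above the negative one. Below is a minimal rejection-sampling sketch under that reading. It is not RecBole's sampler, and `sample_negatives` and `bpr_loss` are hypothetical names.

```python
import math
import random

def sample_negatives(positive_items, all_items, sample_num=1, rng=None):
    """Uniformly draw `sample_num` items the user has not interacted with."""
    rng = rng or random.Random()
    if set(all_items) <= set(positive_items):
        raise ValueError("no negative items left to sample")
    negatives = []
    while len(negatives) < sample_num:
        candidate = rng.choice(all_items)     # 'distribution': 'uniform'
        if candidate not in positive_items:   # reject items the user already has
            negatives.append(candidate)
    return negatives

def bpr_loss(pos_scores, neg_scores):
    """Mean of -log(sigmoid(pos - neg)) over (positive, negative) score pairs."""
    terms = [-math.log(1.0 / (1.0 + math.exp(-(p - n))))
             for p, n in zip(pos_scores, neg_scores)]
    return sum(terms) / len(terms)

# Made-up toy data: one user with two positives in a four-item catalogue.
rng = random.Random(0)
user_pos = {"i1", "i2"}
catalogue = ["i1", "i2", "i3", "i4"]
triples = [("u1", pos, sample_negatives(user_pos, catalogue, 1, rng)[0])
           for pos in sorted(user_pos)]
print(triples)
print(bpr_loss([0.0], [0.0]))  # equal scores give log(2)
```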

save_dataset: Ture
save_dataloaders: Ture
dataset_save_path: Ture
dataloaders_save_path: Ture

Ethan-TZ commented 1 year ago

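@emilyjeng Hello, and thanks for your interest! The yelp datasets we provide generally do not come with .user and .item files. Would you be able to share a small portion of the dataset you are using so that we can test it?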

emilyjeng commented 1 year ago

@chenyuwuxin Hello, I downloaded the yelp2022 files from the Google Drive you provided, and they do include .user and .item files.

[screenshot]
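Whether the .user and .item files actually provide every column listed in `load_col` can be checked from the first line of each file. The sketch below assumes the RecBole atomic-file header convention of tab-separated `field:type` columns. `parse_atomic_header` and `missing_columns` are hypothetical helpers, not RecBole APIs, and the header string is made up for illustration; the real yelp2022.user header may differ.

```python
# Hypothetical helpers, not RecBole APIs. Assumes an atomic-file header of
# tab-separated "field:type" columns, e.g. "user_id:token".

def parse_atomic_header(header_line, sep="\t"):
    """Map each field name in an atomic-file header to its declared type."""
    fields = {}
    for column in header_line.rstrip("\n").split(sep):
        name, _, ftype = column.partition(":")
        fields[name] = ftype
    return fields

def missing_columns(header_line, wanted):
    """Return the names in `wanted` that the header does not declare."""
    declared = parse_atomic_header(header_line)
    return [name for name in wanted if name not in declared]

# Made-up header for illustration only.
header = "user_id:token\tuser_name:token\tuser_useful:float"
print(parse_atomic_header(header))
print(missing_columns(header, ["user_id", "user_name", "user_useful"]))  # []
```

In practice you would read `header` from the first line of the downloaded file and pass in the list from `load_col['user']` or `load_col['item']`.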
Ethan-TZ commented 1 year ago

@emilyjeng Hello, we have just fixed this issue. Please try again.

emilyjeng commented 1 year ago

@chenyuwuxin Hello! I re-downloaded RecBole, but the problem persists. Do I also need to re-download the dataset?