HKUDS / KGRec

[KDD'2023] "KGRec: Knowledge Graph Self-Supervised Rationalization for Recommendation"
https://arxiv.org/abs/2307.02759
Apache License 2.0

Model performance #5

Closed paul0728 closed 1 year ago

paul0728 commented 1 year ago

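Hello, I ran all of the code you provided and have the following questions:

  1. This table shows the results reported in the paper: [image] This table shows the results I obtained by running the code: [image] Some of the numbers differ substantially, and KGRec also loses to KGCL on the MIND dataset. I ran everything with the default settings and changed only the batch size.

  2. [image] The log above reports "early stopping at 10, recall@20:0.0349", but the recall in the table is 0.02996574. Why are the two inconsistent?

Thank you.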

yuh-yang commented 1 year ago

Hello,

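  1. What batch_size did you use for point 1? MIND appears to be quite sensitive to batch_size.

  2. "early stop at 10" means the best performance was reached no later than epoch 10 - early_stop_patience. It looks like the very first epoch was already the best on this dataset, which suggests some overfitting. I have not run into this in my own experiments; could you share your hyperparameter list?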
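To make the relationship between the stopping epoch and the best epoch concrete, here is a minimal sketch of patience-based early stopping that also keeps every metric of the best epoch. It is not taken from the KGRec code: the class `EarlyStopper`, its fields, and the sample numbers are hypothetical, and `patience` only stands in for what the comment above calls early_stop_patience.

```python
from dataclasses import dataclass, field


@dataclass
class EarlyStopper:
    """Patience-based early stopping that keeps every metric of the best epoch."""

    patience: int = 10   # assumed value, standing in for early_stop_patience
    key: str = "recall"  # metric that decides which epoch counts as best
    best_epoch: int = -1
    best_metrics: dict = field(default_factory=dict)
    bad_epochs: int = 0

    def update(self, epoch: int, metrics: dict) -> bool:
        """Record one evaluation; return True when training should stop."""
        if not self.best_metrics or metrics[self.key] > self.best_metrics[self.key]:
            self.best_epoch = epoch
            self.best_metrics = dict(metrics)  # snapshot all metrics, not only recall
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


if __name__ == "__main__":
    stopper = EarlyStopper(patience=3)
    history = [  # made-up numbers, for illustration only
        {"recall": 0.030, "ndcg": 0.020},  # epoch 1: best so far
        {"recall": 0.028, "ndcg": 0.021},
        {"recall": 0.027, "ndcg": 0.019},
        {"recall": 0.026, "ndcg": 0.018},  # third epoch without improvement
    ]
    for epoch, metrics in enumerate(history, start=1):
        if stopper.update(epoch, metrics):
            print(f"early stopping at {epoch}, best epoch {stopper.best_epoch}: {stopper.best_metrics}")
            break
```

With this scheme the metrics of the final epoch are generally worse than the reported best recall, so a results table should be filled from `best_metrics` rather than from the last evaluation.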

paul0728 commented 1 year ago

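Hello,

  1. I did not fix the batch size. As far as I recall, I ran the experiments with 2048, 4096, or 8192 (I was running all the experiments at the same time, so I adjusted it according to how much memory was available).

  2. Understood. So the other metrics at the best-performing epoch, such as ndcg, precision, and hit ratio, are not recorded? My hyperparameters were simply the default values.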

yuh-yang commented 1 year ago
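  1. Could you try the default batch_size on the MIND dataset? Here is a log from a reproduction run for reference.
  2. You can look directly at the eval metrics of the best epoch; they are recorded in the log.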
    
    PID: 2043994
    DESC: 

    ########## Ablation ##########
    ablation: None
    ########## Model HPs ##########
    tau: 0.1
    cL_drop: 0.6
    cl_coef: 0.001
    mae_coef: 0.1
    ########## Model Parameters ##########
    context_hops: 2
    node_dropout: 1
    node_dropout_rate: 0.5
    mess_dropout: 1
    mess_dropout_rate: 0.1
    all_embed: torch.Size([149545, 64])
    interact_mat: torch.Size([2, 2035114])
    edge_index: torch.Size([2, 297058])
    edge_type: torch.Size([297058])
    start training ...
    neg_sampling_cpp time: 2.93s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 4.536376536634634e-07
    2022-12-15 08:59:41: using time 981.4958229064941, training loss at epoch 0: [858.2809901833534, 136.85224232822657, 2.1338425274007022]
    neg_sampling_cpp time: 2.92s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 3.181440661137458e-06
    +-------+-------------------+--------------------+----------------------------------------------------------+--------------+--------------+-------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+-------------------+--------------------+----------------------------------------------------------+--------------+--------------+-------------+--------------+
    | 1 | 983.1496224403381 | 101.02993535995483 | [663.2253972291946, 135.91076271981, 1.6955093519645743] | [0.02838848] | [0.01867379] | [0.0095401] | [0.16800168] |
    +-------+-------------------+--------------------+----------------------------------------------------------+--------------+--------------+-------------+--------------+
    neg_sampling_cpp time: 3.21s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 3.2832700526341796e-05
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+-------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+-------------+--------------+--------------+
    | 2 | 1091.8222153186798 | 82.32155442237854 | [539.152124479413, 135.60504004359245, 1.6517149189021438] | [0.02555022] | [0.0163434] | [0.00845308] | [0.15277153] |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+-------------+--------------+--------------+
    neg_sampling_cpp time: 2.92s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.00014418832142837346
    +-------+--------------------+-------------------+-----------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+-----------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 3 | 1058.4946956634521 | 85.32225203514099 | [445.1733414977789, 135.3312196061015, 1.621447007113602] | [0.02434268] | [0.01500672] | [0.00803058] | [0.14701147] |
    +-------+--------------------+-------------------+-----------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.94s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.0004245133022777736
    +-------+--------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 4 | 1013.2886476516724 | 85.05164742469788 | [386.4210756570101, 134.97420116513968, 1.5945173841319047] | [0.02432741] | [0.01490561] | [0.00802958] | [0.14643146] |
    +-------+--------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 3.11s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.0008968916372396052
    +-------+--------------------+-------------------+--------------------------------------------------------------+--------------+--------------+-------------+-------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+--------------------------------------------------------------+--------------+--------------+-------------+-------------+
    | 5 | 1018.3342461585999 | 81.64792084693909 | [352.78556552529335, 134.33220886439085, 1.5712383096688427] | [0.02977509] | [0.01718717] | [0.0095511] | [0.1701117] |
    +-------+--------------------+-------------------+--------------------------------------------------------------+--------------+--------------+-------------+-------------+
    neg_sampling_cpp time: 2.87s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.0016492039430886507
    +-------+--------------------+-------------------+-----------------------------------------------------------+--------------+-------------+-------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+-----------------------------------------------------------+--------------+-------------+-------------+--------------+
    | 6 | 1016.8289427757263 | 78.96694374084473 | [332.3973051458597, 132.3633598163724, 1.551181080925744] | [0.03213073] | [0.0189466] | [0.0102556] | [0.17893179] |
    +-------+--------------------+-------------------+-----------------------------------------------------------+--------------+-------------+-------------+--------------+
    neg_sampling_cpp time: 2.95s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.00250235409475863
    +-------+-------------------+--------------------+--------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+-------------------+--------------------+--------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 7 | 987.1058983802795 | 100.14211750030518 | [319.02492478489876, 127.54480486735702, 1.5329861902282573] | [0.03371124] | [0.01970759] | [0.01094811] | [0.18779188] |
    +-------+-------------------+--------------------+--------------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.99s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.0037901264149695635
    +-------+-------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+-------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 8 | 990.9507648944855 | 91.51290035247803 | [310.8137071505189, 120.60236733779311, 1.5170738758752123] | [0.03960177] | [0.02310669] | [0.01328813] | [0.21625216] |
    +-------+-------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.96s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.007285455707460642
    +-------+--------------------+-------------------+-----------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+-----------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 9 | 1005.7628993988037 | 80.96233057975769 | [304.3987214118242, 112.614363219589, 1.5030067405314185] | [0.04046883] | [0.02521829] | [0.01348613] | [0.21946219] |
    +-------+--------------------+-------------------+-----------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.94s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.012098983861505985
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 10 | 1005.2527062892914 | 78.66200017929077 | [298.7598610371351, 104.4885897859931, 1.4908959928434342] | [0.04066323] | [0.02792238] | [0.01363564] | [0.22108221] |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.94s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.01825389638543129
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+-------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+-------------+
    | 11 | 1007.1901774406433 | 85.74130892753601 | [294.0891420841217, 96.17867673188448, 1.4800511177745648] | [0.04293238] | [0.03158036] | [0.01431564] | [0.2300923] |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+-------------+
    neg_sampling_cpp time: 2.93s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.029202202335000038
    +-------+-------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+-------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 12 | 993.5224199295044 | 99.09122490882874 | [289.8181935995817, 87.36584776639938, 1.4702649217797443] | [0.04387883] | [0.03194437] | [0.01445865] | [0.23103233] |
    +-------+-------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 3.08s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.028071589767932892
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 13 | 1031.6883862018585 | 78.71787476539612 | [286.12213262170553, 77.74990151450038, 1.461867157719098] | [0.04277731] | [0.03173141] | [0.01432364] | [0.22431224] |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.96s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.036324393004179
    +-------+------------------+-------------------+-----------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+------------------+-------------------+-----------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 14 | 1013.68630027771 | 80.84501814842224 | [281.6399531438947, 68.5174616407603, 1.4546180783654563] | [0.03961291] | [0.03055458] | [0.01297813] | [0.20915209] |
    +-------+------------------+-------------------+-----------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.85s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.09908858686685562
    +-------+--------------------+------------------+--------------------------------------------------------------+------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+------------------+--------------------------------------------------------------+------------+--------------+--------------+--------------+
    | 15 | 1066.3184566497803 | 73.2188663482666 | [276.61744425445795, 60.048723665997386, 1.4483676520176232] | [0.040843] | [0.03266636] | [0.01343813] | [0.21546215] |
    +-------+--------------------+------------------+--------------------------------------------------------------+------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.92s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.10782324522733688
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 16 | 1060.7459824085236 | 80.94961190223694 | [272.18506214767694, 52.8009411431849, 1.4428192895720713] | [0.04073255] | [0.03289634] | [0.01330763] | [0.21518215] |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.94s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.12445466220378876
    +-------+--------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 17 | 1021.2842864990234 | 82.45049953460693 | [266.87108846753836, 46.76785609871149, 1.4376955690095201] | [0.04010637] | [0.03143557] | [0.01311063] | [0.21205212] |
    +-------+--------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.86s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.15875782072544098
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 18 | 1018.3350901603699 | 86.64400720596313 | [261.9447982907295, 41.95732254907489, 1.4331976809189655] | [0.04033699] | [0.03366746] | [0.01313213] | [0.21311213] |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.96s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.1483171284198761
    +-------+--------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 19 | 1011.9683840274811 | 81.47498273849487 | [257.97484501451254, 37.92921615112573, 1.4288659388548695] | [0.03837464] | [0.02969172] | [0.01248562] | [0.20324203] |
    +-------+--------------------+-------------------+-------------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.92s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.1949867159128189
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+-------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+-------------+
    | 20 | 1044.9123375415802 | 79.50294637680054 | [254.2951421365142, 34.845006006769836, 1.425107032933738] | [0.03648797] | [0.02605076] | [0.01162112] | [0.1904519] |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+-------------+
    neg_sampling_cpp time: 2.96s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.17475204169750214
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    | 21 | 1044.1717176437378 | 80.40379881858826 | [249.09019066393375, 32.05059486068785, 1.421627540839836] | [0.03421558] | [0.02507725] | [0.01076711] | [0.17792178] |
    +-------+--------------------+-------------------+------------------------------------------------------------+--------------+--------------+--------------+--------------+
    neg_sampling_cpp time: 2.86s
    train_cf_triples shape: (2035114, 3)
    edge_attn_score std: 0.285133957862854
    +-------+--------------------+-------------------+--------------------------------------------------------------+--------------+--------------+-------------+-------------+
    | Epoch | training time | tesing time | Loss | recall | ndcg | precision | hit_ratio |
    +-------+--------------------+-------------------+--------------------------------------------------------------+--------------+--------------+-------------+-------------+
    | 22 | 1053.8228340148926 | 78.07096815109253 | [245.97367100417614, 29.855773952789605, 1.4183097369968891] | [0.03258255] | [0.02205351] | [0.0101541] | [0.1696017] |
    +-------+--------------------+-------------------+--------------------------------------------------------------+--------------+--------------+-------------+-------------+
    early stopping at 22, recall@20:0.0439
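In the log above, the highest recall is at epoch 12 (0.04387883), which matches the final "recall@20:0.0439" line; the run then stops at epoch 22, which is consistent with a patience of 10 epochs, although the log does not print that value. The following is a rough sketch of how the best epoch's metrics could be pulled out of such a log. The regular expression and the helpers `parse_eval_rows` and `best_epoch` are hypothetical and only assume the row layout shown in the log; they are not part of the repository.

```python
import re

# One evaluation row: | epoch | train time | test time | [loss list] | [recall] | [ndcg] | [precision] | [hit_ratio] |
ROW = re.compile(
    r"\|\s*(\d+)\s*\|\s*[\d.]+\s*\|\s*[\d.]+\s*\|\s*\[[^\]]*\]\s*"
    r"\|\s*\[([\d.eE+-]+)\]\s*\|\s*\[([\d.eE+-]+)\]\s*"
    r"\|\s*\[([\d.eE+-]+)\]\s*\|\s*\[([\d.eE+-]+)\]\s*\|"
)
METRICS = ("recall", "ndcg", "precision", "hit_ratio")


def parse_eval_rows(log_text):
    """Return one dict per evaluation row found anywhere in the log text."""
    rows = []
    for match in ROW.finditer(log_text):
        row = {"epoch": int(match.group(1))}
        for name, value in zip(METRICS, match.groups()[1:]):
            row[name] = float(value)
        rows.append(row)
    return rows


def best_epoch(rows, key="recall"):
    """Pick the row with the highest value of the chosen metric."""
    return max(rows, key=lambda row: row[key])


# Three rows copied from the log above (epochs 11 to 13).
sample = """
| 11 | 1007.1901774406433 | 85.74130892753601 | [294.0891420841217, 96.17867673188448, 1.4800511177745648] | [0.04293238] | [0.03158036] | [0.01431564] | [0.2300923] |
| 12 | 993.5224199295044 | 99.09122490882874 | [289.8181935995817, 87.36584776639938, 1.4702649217797443] | [0.04387883] | [0.03194437] | [0.01445865] | [0.23103233] |
| 13 | 1031.6883862018585 | 78.71787476539612 | [286.12213262170553, 77.74990151450038, 1.461867157719098] | [0.04277731] | [0.03173141] | [0.01432364] | [0.22431224] |
"""
best = best_epoch(parse_eval_rows(sample))
print(best["epoch"], best["recall"], best["ndcg"], best["precision"], best["hit_ratio"])
```

Because the pattern is matched with `finditer` over the whole text, it does not depend on where the line breaks fall, so it should also cope with a log that was pasted as one long line.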

paul0728 commented 1 year ago

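Thank you, I will try again. I think I simply was not taking the best result; until now I had been using the results from the last epoch. Also, the KGIN code has some problems that you may want to check: the imports in run_kgin.py and evaluate.py both need to be fixed before the program runs correctly.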

yuh-yang commented 1 year ago

Thanks, I will check it.