ralgond closed this issue 8 months ago
Sorry for the late response, I just noticed the question.
I cannot access the Google Drive link; can you make it publicly accessible? You can then @ me so that I receive an email notification.
The file should be accessible now.
How is the train/valid/test split derived from this user_history.pkl? Can I just use the last item for the test set?
First of all, please install pandas.
Then run the following code:
```python
import pickle

with open("user_history.pkl", "rb") as inf:
    obj = pickle.load(inf)
print(obj)
```
How you split the train set and the valid set is up to you.
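For instance, if you want the common leave-one-out convention (each user's last item held out for validation), a minimal sketch, assuming the loaded object maps each user id to a chronological item sequence (that layout is my assumption):

```python
def leave_one_out_split(user_history):
    """Hold out each user's last item for validation.

    user_history: dict mapping user_id -> chronological list of item ids
    (this layout is an assumption about user_history.pkl).
    """
    train, valid = {}, {}
    for user_id, item_seq in user_history.items():
        if len(item_seq) < 2:
            # too short to hold anything out; keep it all for training
            train[user_id] = item_seq
            continue
        train[user_id] = item_seq[:-1]  # everything but the last item
        valid[user_id] = item_seq[-1]   # last item held out for validation
    return train, valid

# toy example; real sequences would come from user_history.pkl
train, valid = leave_one_out_split({1: [10, 11, 12], 2: [20]})
```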
I just trained one model with the default hyper-parameters:
```shell
###############################################################################################
LOCAL_ROOT="/home/jialia/tmp/unirec/UniRec"  # path to UniRec
###############################################################################################

MY_DIR=$LOCAL_ROOT
ALL_DATA_ROOT="$LOCAL_ROOT/data"
OUTPUT_ROOT="$LOCAL_ROOT/output"

MODEL_NAME='SASRec'      # [MF, AvgHist, AttHist, SVDPlusPlus, GRU, SASRec, ConvFormer, FASTConvFormer]
loss_type='fullsoftmax'  # [bce, bpr, softmax, ccl, fullsoftmax]
distance_type='dot'      # [cosine, mlp, dot]
DATASET_NAME="github"
max_seq_len=20
verbose=2
learning_rate=0.0002     #7532020029371717
weight_decay=0           #1e-6

cd $MY_DIR
export PYTHONPATH=$PWD

DATA_TYPE='SeqRecDataset'   # BaseDataset SeqRecDataset
test_protocol='one_vs_all'  # 'one_vs_k' 'one_vs_all' 'session_aware'

exp_name="$MODEL_NAME-$DATASET_NAME"

ALL_RESULTS_ROOT="$OUTPUT_ROOT/$DATASET_NAME/$MODEL_NAME"
mkdir -p $ALL_RESULTS_ROOT

/usr/bin/python3.9 unirec/main/main.py \
    --config_dir="unirec/config" \
    --model=$MODEL_NAME \
    --dataloader=$DATA_TYPE \
    --dataset=$DATASET_NAME \
    --dataset_path=$ALL_DATA_ROOT"/"$DATASET_NAME \
    --output_path=$ALL_RESULTS_ROOT"/train" \
    --learning_rate=$learning_rate \
    --dropout_prob=0.0 \
    --embedding_size=64 \
    --hidden_size=64 \
    --use_pre_item_emb=0 \
    --loss_type=$loss_type \
    --max_seq_len=$max_seq_len \
    --has_user_bias=0 \
    --has_item_bias=0 \
    --epochs=100 \
    --early_stop=10 \
    --batch_size=512 \
    --n_sample_neg_train=0 \
    --neg_by_pop_alpha=0 \
    --valid_protocol=$test_protocol \
    --test_protocol=$test_protocol \
    --grad_clip_value=-1 \
    --weight_decay=$weight_decay \
    --history_mask_mode='autoregressive' \
    --user_history_filename="user_history" \
    --user_history_file_format="user-item_seq" \
    --metrics="['hit@10;20', 'ndcg@10;20']" \
    --key_metric="ndcg@10" \
    --num_workers=4 \
    --num_workers_test=0 \
    --verbose=$verbose \
    --exp_name=$exp_name \
    --distance_type=$distance_type \
    --use_wandb=0
```
And got the following results:
```
[INFO] SASRec-SASRec-github: epoch 0 evaluating [time: 50.12s, ndcg@10: 0.000291]
[INFO] SASRec-SASRec-github: complete scores on valid set: hit@10:0.00072 hit@20:0.00106 ndcg@10:0.0002905489186220994 ndcg@20:0.00037539053820864944
[INFO] SASRec-SASRec-github: Saving best model at epoch 0 to /home/jialia/tmp/unirec/UniRec/output/github/SASRec/train/checkpoint_2024-01-28_095355_92/SASRec-SASRec-github.pth
[INFO] SASRec-SASRec-github: epoch 1
[INFO] SASRec-SASRec-github: epoch 1 training [time: 37.51s, train loss: 11888.1119]
[INFO] SASRec-SASRec-github: one_vs_all
[INFO] SASRec-SASRec-github: epoch 1 evaluating [time: 43.49s, ndcg@10: 0.060777]
[INFO] SASRec-SASRec-github: complete scores on valid set: hit@10:0.122065 hit@20:0.213795 ndcg@10:0.0607769720065766 ndcg@20:0.08359101588481657
[INFO] SASRec-SASRec-github: Saving best model at epoch 1 to /home/jialia/tmp/unirec/UniRec/output/github/SASRec/train/checkpoint_2024-01-28_095355_92/SASRec-SASRec-github.pth
[INFO] SASRec-SASRec-github: epoch: 1, learning rate: 0.0002
[INFO] SASRec-SASRec-github: epoch 2
[INFO] SASRec-SASRec-github: epoch 2 training [time: 37.21s, train loss: 10133.0652]
[INFO] SASRec-SASRec-github: one_vs_all
[INFO] SASRec-SASRec-github: epoch 2 evaluating [time: 43.41s, ndcg@10: 0.165069]
[INFO] SASRec-SASRec-github: complete scores on valid set: hit@10:0.2878 hit@20:0.376515 ndcg@10:0.16506895152294968 ndcg@20:0.1875047732204879
[INFO] SASRec-SASRec-github: Saving best model at epoch 2 to /home/jialia/tmp/unirec/UniRec/output/github/SASRec/train/checkpoint_2024-01-28_095355_92/SASRec-SASRec-github.pth
[INFO] SASRec-SASRec-github: epoch: 2, learning rate: 0.0002
[INFO] SASRec-SASRec-github: epoch 3
[INFO] SASRec-SASRec-github: epoch 3 training [time: 37.05s, train loss: 9490.9100]
[INFO] SASRec-SASRec-github: one_vs_all
[INFO] SASRec-SASRec-github: epoch 3 evaluating [time: 43.18s, ndcg@10: 0.202454]
[INFO] SASRec-SASRec-github: complete scores on valid set: hit@10:0.34293 hit@20:0.448205 ndcg@10:0.2024542686899938 ndcg@20:0.22908459868123437
[INFO] SASRec-SASRec-github: Saving best model at epoch 3 to /home/jialia/tmp/unirec/UniRec/output/github/SASRec/train/checkpoint_2024-01-28_095355_92/SASRec-SASRec-github.pth
[INFO] SASRec-SASRec-github: epoch: 3, learning rate: 0.0002
[INFO] SASRec-SASRec-github: epoch 4
[INFO] SASRec-SASRec-github: epoch 4 training [time: 37.24s, train loss: 9271.6420]
[INFO] SASRec-SASRec-github: one_vs_all
[INFO] SASRec-SASRec-github: epoch 4 evaluating [time: 41.89s, ndcg@10: 0.211230]
[INFO] SASRec-SASRec-github: complete scores on valid set: hit@10:0.35965 hit@20:0.47127 ndcg@10:0.2112298167533041 ndcg@20:0.239386022708208
[INFO] SASRec-SASRec-github: Saving best model at epoch 4 to /home/jialia/tmp/unirec/UniRec/output/github/SASRec/train/checkpoint_2024-01-28_095355_92/SASRec-SASRec-github.pth
[INFO] SASRec-SASRec-github: epoch: 4, learning rate: 0.0002
[INFO] SASRec-SASRec-github: epoch 5
[INFO] SASRec-SASRec-github: epoch 5 training [time: 37.33s, train loss: 9151.2752]
[INFO] SASRec-SASRec-github: one_vs_all
[INFO] SASRec-SASRec-github: epoch 5 evaluating [time: 42.59s, ndcg@10: 0.210935]
[INFO] SASRec-SASRec-github: complete scores on valid set: hit@10:0.359915 hit@20:0.474425 ndcg@10:0.2109346346254805 ndcg@20:0.2399600573056946
[INFO] SASRec-SASRec-github: No better score in the epoch. Patience: 1 / 10
[INFO] SASRec-SASRec-github: epoch: 5, learning rate: 0.0002
[INFO] SASRec-SASRec-github: epoch 6
[INFO] SASRec-SASRec-github: epoch 6 training [time: 37.24s, train loss: 9087.8924]
[INFO] SASRec-SASRec-github: one_vs_all
[INFO] SASRec-SASRec-github: epoch 6 evaluating [time: 43.18s, ndcg@10: 0.216422]
[INFO] SASRec-SASRec-github: complete scores on valid set: hit@10:0.37097 hit@20:0.481085 ndcg@10:0.21642232238558282 ndcg@20:0.24433307328187334
[INFO] SASRec-SASRec-github: Saving best model at epoch 6 to /home/jialia/tmp/unirec/UniRec/output/github/SASRec/train/checkpoint_2024-01-28_095355_92/SASRec-SASRec-github.pth
[INFO] SASRec-SASRec-github: epoch: 6, learning rate: 0.0002
[INFO] SASRec-SASRec-github: epoch 7
[INFO] SASRec-SASRec-github: epoch 7 training [time: 37.00s, train loss: 9044.4685]
[INFO] SASRec-SASRec-github: one_vs_all
[INFO] SASRec-SASRec-github: epoch 7 evaluating [time: 43.80s, ndcg@10: 0.218103]
[INFO] SASRec-SASRec-github: complete scores on valid set: hit@10:0.3739 hit@20:0.48573 ndcg@10:0.2181032363370788 ndcg@20:0.24637894420165377
[INFO] SASRec-SASRec-github: Saving best model at epoch 7 to /home/jialia/tmp/unirec/UniRec/output/github/SASRec/train/checkpoint_2024-01-28_095355_92/SASRec-SASRec-github.pth
[INFO] SASRec-SASRec-github: epoch: 7, learning rate: 0.0002
[INFO] SASRec-SASRec-github: epoch 8
[INFO] SASRec-SASRec-github: epoch 8 training [time: 37.14s, train loss: 9008.2174]
[INFO] SASRec-SASRec-github: one_vs_all
[INFO] SASRec-SASRec-github: epoch 8 evaluating [time: 42.79s, ndcg@10: 0.227004]
[INFO] SASRec-SASRec-github: complete scores on valid set: hit@10:0.38816 hit@20:0.497215 ndcg@10:0.227003882833139 ndcg@20:0.25460645051101377
[INFO] SASRec-SASRec-github: Saving best model at epoch 8 to /home/jialia/tmp/unirec/UniRec/output/github/SASRec/train/checkpoint_2024-01-28_095355_92/SASRec-SASRec-github.pth
[INFO] SASRec-SASRec-github: epoch: 8, learning rate: 0.0002
[INFO] SASRec-SASRec-github: epoch 9
[INFO] SASRec-SASRec-github: epoch 9 training [time: 37.28s, train loss: 8964.1008]
[INFO] SASRec-SASRec-github: one_vs_all
[INFO] SASRec-SASRec-github: epoch 9 evaluating [time: 43.93s, ndcg@10: 0.232188]
[INFO] SASRec-SASRec-github: complete scores on valid set: hit@10:0.39465 hit@20:0.5038 ndcg@10:0.2321883789085995 ndcg@20:0.259833469948302
```
With the same settings, switching to GRU for 30 epochs gives: hit@10 0.38521, hit@20 0.49989, ndcg@10 0.22414061177355216, ndcg@20 0.25317436009739047.
nice job
I cannot find the splitting logic for the train and valid datasets. Are you using the whole dataset as the valid dataset?
My splitting scheme puts records in the train set when user_id < 200000*0.1; the others go to the valid set.
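A minimal sketch of that scheme, assuming the records are (user_id, item_id) pairs (the exact record layout is my assumption, not something stated in the thread):

```python
THRESHOLD = 200000 * 0.1  # user ids below this go to the train set

def split_by_user_id(records):
    """records: iterable of (user_id, item_id) pairs -- layout is assumed."""
    train = [r for r in records if r[0] < THRESHOLD]
    valid = [r for r in records if r[0] >= THRESHOLD]
    return train, valid

# toy records; real ones would come from user_history.pkl
train, valid = split_by_user_id([(5, 1), (19999, 2), (20000, 3), (150000, 4)])
```

Note that this splits by user rather than by time, so the valid users never appear in training at all, unlike a leave-one-out split.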
I use the leave-one-out split for the valid dataset.
I must have misused UniRec.
Another question: I have a test dataset; how can I recommend the top-5 items for each session?
I think you mean topk_reco.py; there is an example here: https://github.com/microsoft/UniRec/blob/main/examples/more-examples/kddcup2023/topk_score.sh
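If you need top-k recommendation outside UniRec's scripts, the core operation is just a top-k over a session-by-item score matrix; a model-agnostic numpy sketch (the scores themselves would come from your trained model, not from this code):

```python
import numpy as np

def topk_items(scores, k=5):
    """scores: (n_sessions, n_items) array of model scores.

    Returns the item indices of the k highest scores per session, best first.
    """
    # argpartition puts the k largest scores in the first k slots in O(n),
    # then we sort just those k to get them in descending-score order
    part = np.argpartition(-scores, k - 1, axis=1)[:, :k]
    order = np.argsort(np.take_along_axis(-scores, part, axis=1), axis=1)
    return np.take_along_axis(part, order, axis=1)

# one toy session with 6 candidate items
scores = np.array([[0.1, 0.9, 0.3, 0.7, 0.5, 0.2]])
top3 = topk_items(scores, k=3)  # item ids 1, 3, 4
```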
Hi there,
I am training a sequential recommender. The user_history file is as follows:
https://drive.google.com/file/d/1kV3Wk7KUvRVONJnt0quuSs39XVWRpi6w/view?usp=drive_link
In my past experience, hit@10 should be about 0.2 (I got that result with Transformers4Rec), but UniRec gives 0.004. I changed almost nothing in the configuration, just switched SASRec to GRU.