hungpthanh / GRU4REC-pytorch

Another implementation of GRU4REC using PyTorch
Apache License 2.0

GRU4REC-PyTorch

Requirements

Usage

Dataset

The RecSys Challenge 2015 dataset can be retrieved from HERE

Preprocessing data

The data format is similar to that of the RecSys Challenge 2015 dataset.
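Each row of the raw yoochoose-clicks.dat file is comma-separated: session ID, timestamp, item ID, category. The snippet below is a minimal preprocessing sketch, not the repository's preprocessing.py. It turns those rows into (session, item, Unix time) records and drops single-click sessions, because a session with one click has no next item to predict. The helper names and the SessionId,ItemId,Time output header are assumptions for illustration.

```python
import csv
import io
from collections import Counter
from datetime import datetime, timezone


def parse_clicks(lines):
    """Turn raw yoochoose-clicks.dat rows into (session_id, item_id, unix_time) tuples.

    A raw row looks like: 1,2014-04-07T10:51:09.277Z,214536502,0
    (session ID, timestamp, item ID, category); the category is dropped.
    """
    records = []
    for row in csv.reader(lines):
        if not row:  # skip blank lines
            continue
        session_id, timestamp, item_id, _category = row
        # The trailing 'Z' means UTC: strip it and attach the UTC timezone explicitly.
        dt = datetime.strptime(timestamp.rstrip("Z"), "%Y-%m-%dT%H:%M:%S.%f")
        unix_time = dt.replace(tzinfo=timezone.utc).timestamp()
        records.append((int(session_id), int(item_id), unix_time))
    return records


def drop_short_sessions(records, min_length=2):
    """Remove sessions with fewer than min_length clicks (no next item to predict)."""
    lengths = Counter(session_id for session_id, _, _ in records)
    return [r for r in records if lengths[r[0]] >= min_length]


def write_processed(records, out):
    """Write records as CSV under an assumed SessionId,ItemId,Time header."""
    writer = csv.writer(out)
    writer.writerow(["SessionId", "ItemId", "Time"])
    writer.writerows(records)


if __name__ == "__main__":
    raw = ["1,2014-04-07T10:51:09.277Z,214536502,0",
           "1,2014-04-07T10:54:09.868Z,214536500,0",
           "2,2014-04-07T13:56:37.614Z,214662742,0"]
    buffer = io.StringIO()
    write_processed(drop_short_sessions(parse_clicks(raw)), buffer)
    print(buffer.getvalue())  # session 2 has a single click, so it is dropped
```

A real preprocessing step will probably also filter rare items and split the data into training and validation files. Those steps are omitted here.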

Training and Testing

The project has the following structure:

├── GRU4REC-pytorch
│   ├── checkpoint
│   ├── data
│   │    ├── preprocessed_data
│   │    │    ├── recSys15TrainOnly.txt
│   │    │    ├── recSys15Valid.txt
│   │    ├── raw_data
│   │    │    ├── yoochoose-clicks.dat
│   ├── lib
│   ├── main.py
│   ├── preprocessing.py
│   ├── tool.py

tool.py can be used to extract the last 1/8 of the sessions from yoochoose-clicks.dat.
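The description of tool.py does not say how "last" is measured. The minimal sketch below assumes that sessions are ranked by their first click time and keeps the newest 1/8 of them. The input is a list of (session_id, item_id, unix_time) tuples. The function name `last_fraction_of_sessions` is hypothetical, and the sketch is not the code in tool.py.

```python
def last_fraction_of_sessions(records, denominator=8):
    """Keep the clicks of the most recent 1/denominator of sessions.

    records: list of (session_id, item_id, unix_time) tuples.
    Assumption: a session's age is the time of its first click.
    """
    # Earliest click time per session.
    start = {}
    for session_id, _item_id, t in records:
        start[session_id] = min(t, start.get(session_id, t))

    # Order session IDs from oldest to newest by first click.
    ordered = sorted(start, key=start.get)
    n_keep = len(ordered) // denominator

    # Slice with an explicit start index so n_keep == 0 keeps nothing
    # (ordered[-0:] would otherwise return the whole list).
    keep = set(ordered[len(ordered) - n_keep:])
    return [r for r in records if r[0] in keep]
```

This keeps whole sessions rather than cutting by a time threshold, so no session is split across the boundary.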

Run the following commands from the GRU4REC-pytorch directory.

Training

python main.py
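GRU4Rec, as described by Hidasi et al., trains on session-parallel mini-batches. Each of the --batch_size slots follows one session. When a session has no next click left, the next unused session takes over that slot, and the slot's hidden state is reset. The plain-Python sketch below illustrates the idea under stated assumptions. The function name `session_parallel_batches` and its rule of stopping as soon as a slot cannot be refilled are hypothetical, not taken from lib/.

```python
def session_parallel_batches(sessions, batch_size):
    """Yield (inputs, targets, reset) lists for session-parallel mini-batches.

    sessions: list of item-id lists, each with at least 2 items.
    reset[i] is True when slot i has just started a new session, i.e. the
    hidden state for that slot should be zeroed before this step.
    """
    if len(sessions) < batch_size:
        return
    slots = list(range(batch_size))   # which session each slot is reading
    pos = [0] * batch_size            # current position inside that session
    next_session = batch_size         # first session not yet assigned to a slot
    reset = [True] * batch_size
    while True:
        inputs = [sessions[s][p] for s, p in zip(slots, pos)]
        targets = [sessions[s][p + 1] for s, p in zip(slots, pos)]
        yield inputs, targets, reset
        reset = [False] * batch_size
        for i in range(batch_size):
            pos[i] += 1
            if pos[i] + 1 >= len(sessions[slots[i]]):  # no next item left
                if next_session >= len(sessions):
                    return  # assumption: stop once a slot cannot be refilled
                slots[i] = next_session
                pos[i] = 0
                next_session += 1
                reset[i] = True
```

Each step predicts the next click of batch_size different sessions at once. Only the finished slots are reset, so the other slots keep their hidden state across steps.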

Testing

python main.py --is_eval --load_model checkpoint/CHECKPOINT#/model_EPOCH#.pt
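Evaluation with --is_eval scores the model with Recall@K and MRR@K, where K is set by --k_eval (default 20). The minimal sketch below shows how both metrics can be computed for a batch. It assumes one score per item ID and an optimistic tie rule, where only strictly higher scores push the target down the ranking. The function name `recall_and_mrr_at_k` is hypothetical, and this is not the repository's evaluation code.

```python
def recall_and_mrr_at_k(scores, targets, k=20):
    """Return (Recall@K, MRR@K) for a batch.

    scores:  list of score lists, one score per item ID, one list per example.
    targets: the true next item ID for each example.
    """
    hits = 0
    reciprocal_rank_sum = 0.0
    for row, target in zip(scores, targets):
        # Rank = 1 + number of items scored strictly higher than the target.
        rank = 1 + sum(1 for s in row if s > row[target])
        if rank <= k:
            hits += 1
            reciprocal_rank_sum += 1.0 / rank  # targets outside the top K add 0
    n = len(targets)
    return hits / n, reciprocal_rank_sum / n
```

For example, `recall_and_mrr_at_k([[0.1, 0.9, 0.3], [0.8, 0.1, 0.5]], [2, 1], k=2)` ranks the first target 2nd (a hit) and the second target 3rd (a miss), giving (0.5, 0.25).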

List of accepted arguments

--hidden_size Number of Neurons per Layer (Default = 100)
--num_layers Number of Hidden Layers (Default = 1)
--batch_size Batch Size (Default = 50)
--dropout_input Dropout ratio at input (Default = 0)
--dropout_hidden Dropout at each hidden layer except the last one (Default = 0.5)
--n_epochs Number of epochs (Default = 10)
--k_eval Value of K used during Recall@K and MRR@K evaluation (Default = 20)
--optimizer_type Optimizer (Default = Adagrad)
--final_act Final activation function (Default = Tanh)
--lr Learning rate (Default = 0.01)
--weight_decay Weight decay (Default = 0)
--momentum Momentum Value (Default = 0)
--eps Epsilon Value of Optimizer (Default = 1e-6)
--loss_type Type of loss function: TOP1 / BPR / TOP1-max / BPR-max / Cross-Entropy (Default = TOP1-max)
--time_sort Use in case items are not sorted by timestamp (Default = 0)
--model_name Name of the model (string)
--save_dir Folder in which to save the checkpoints and logs (Default = /checkpoint)
--data_folder Path to the folder containing the dataset
--train_data Name of the training dataset file (Default = recSys15TrainOnly.txt)
--valid_data Name of the validation dataset file (Default = recSys15Valid.txt)
--is_eval Run evaluation only, using a checkpoint model
--load_model Path to the checkpoint model to use for evaluation
--checkpoint_dir Path to the checkpoints folder
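The snippet below mirrors the documented flags and defaults with Python's argparse, which can help when scripting experiments. It is a hypothetical reconstruction, not the actual main.py. The types, the loss_type spellings used as `choices`, and leaving undocumented defaults as None are all assumptions.

```python
import argparse


def build_parser():
    # Hypothetical reconstruction of the documented flags; the real main.py may differ.
    p = argparse.ArgumentParser(description="GRU4REC-pytorch (sketch)")
    p.add_argument("--hidden_size", type=int, default=100)
    p.add_argument("--num_layers", type=int, default=1)
    p.add_argument("--batch_size", type=int, default=50)
    p.add_argument("--dropout_input", type=float, default=0.0)
    p.add_argument("--dropout_hidden", type=float, default=0.5)
    p.add_argument("--n_epochs", type=int, default=10)
    p.add_argument("--k_eval", type=int, default=20)
    p.add_argument("--optimizer_type", type=str, default="Adagrad")
    p.add_argument("--final_act", type=str, default="Tanh")
    p.add_argument("--lr", type=float, default=0.01)
    p.add_argument("--weight_decay", type=float, default=0.0)
    p.add_argument("--momentum", type=float, default=0.0)
    p.add_argument("--eps", type=float, default=1e-6)
    # Spellings copied from the list above; the real code may accept other strings.
    p.add_argument("--loss_type", type=str, default="TOP1-max",
                   choices=["TOP1", "BPR", "TOP1-max", "BPR-max", "Cross-Entropy"])
    p.add_argument("--time_sort", type=int, default=0)
    p.add_argument("--model_name", type=str)                   # no documented default
    p.add_argument("--save_dir", type=str, default="/checkpoint")
    p.add_argument("--data_folder", type=str)                  # no documented default
    p.add_argument("--train_data", type=str, default="recSys15TrainOnly.txt")
    p.add_argument("--valid_data", type=str, default="recSys15Valid.txt")
    p.add_argument("--is_eval", action="store_true")           # bare flag, as in the Testing command
    p.add_argument("--load_model", type=str)
    p.add_argument("--checkpoint_dir", type=str)
    return p
```

For example, `build_parser().parse_args(["--is_eval", "--load_model", "checkpoint/CHECKPOINT#/model_EPOCH#.pt"])` reproduces the Testing command's arguments.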

Results

Different loss functions and hyperparameters have been tried out; the results can be found HERE
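The loss functions selectable with --loss_type are ranking losses from the GRU4Rec papers. The sketch below writes them for a single example in plain Python, with one positive (target) score and a list of negative scores. Mini-batch implementations commonly use the other targets in the batch as negatives, but whether this repository does so is an assumption. The BPR-max score-regularization term from the paper is omitted, and none of these function names come from the repository.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def bpr(pos, negs):
    # Mean of -log sigmoid(r_i - r_j) over negatives j.
    return -sum(math.log(sigmoid(pos - n)) for n in negs) / len(negs)


def top1(pos, negs):
    # Mean of sigmoid(r_j - r_i) plus a regularizer sigmoid(r_j^2).
    return sum(sigmoid(n - pos) + sigmoid(n * n) for n in negs) / len(negs)


def bpr_max(pos, negs):
    # Softmax-weighted BPR over negatives (regularization term omitted).
    weights = softmax(negs)
    return -math.log(sum(w * sigmoid(pos - n) for w, n in zip(weights, negs)))


def top1_max(pos, negs):
    # Softmax-weighted TOP1 over negatives.
    weights = softmax(negs)
    return sum(w * (sigmoid(n - pos) + sigmoid(n * n)) for w, n in zip(weights, negs))


def cross_entropy(pos, negs):
    # Negative log-probability of the target under a softmax over target + negatives.
    return -math.log(softmax([pos] + list(negs))[0])
```

In the -max variants, the softmax weights focus the loss on the highest-scoring negatives, which are the ones most likely to outrank the target.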