hidasib / GRU4Rec

GRU4Rec is the original Theano implementation of the algorithm described in the paper "Session-based Recommendations with Recurrent Neural Networks" (ICLR 2016) and its follow-up, "Recurrent Neural Networks with Top-k Gains for Session-based Recommendations". The code is optimized for execution on the GPU.

No hidden state reset in get_metrics #38

Closed: JustAnotherDataScientist closed this issue 4 years ago

JustAnotherDataScientist commented 4 years ago

In the current version of the code, it seems that the hidden states are not reset for sessions that end within a test batch.

I modified the get_metrics function in my local environment as follows:

def get_metrics(model, args, train_generator_map, recall_k=10, mrr_k=10):
    test_dataset = SessionDataset(args.test_data, itemmap=train_generator_map)
    test_generator = SessionDataLoader(test_dataset, batch_size=args.batch_size)

    n = 0
    rec_sum = 0
    mrr_sum = 0

    with tqdm(total=args.test_samples_qty) as pbar:
        for feat, label, mask in test_generator:
            # --- added: zero the hidden state of sessions that ended ---
            real_mask = np.ones((args.batch_size, 1))
            for elt in mask:
                real_mask[elt, :] = 0

            hidden_states = get_states(model)[0]
            hidden_states = np.multiply(real_mask, hidden_states)
            hidden_states = np.array(hidden_states, dtype=np.float32)
            model.layers[1].reset_states(hidden_states)
            # --- end of added block ---

            target_oh = to_categorical(label, num_classes=args.train_n_items)
            input_oh = to_categorical(feat,  num_classes=args.train_n_items)
            input_oh = np.expand_dims(input_oh, axis=1)

            pred = model.predict(input_oh, batch_size=args.batch_size)

            for row_idx in range(feat.shape[0]):
                pred_row = pred[row_idx] 
                label_row = target_oh[row_idx]

                rec_idx = pred_row.argsort()[-recall_k:][::-1]
                mrr_idx = pred_row.argsort()[-mrr_k:][::-1]
                tru_idx = label_row.argsort()[-1:][::-1]

                n += 1

                if tru_idx[0] in rec_idx:
                    rec_sum += 1

                if tru_idx[0] in mrr_idx:
                    mrr_sum += 1 / (np.where(mrr_idx == tru_idx[0])[0][0] + 1)

            pbar.set_description("Evaluating model")
            pbar.update(test_generator.done_sessions_counter)

    recall = rec_sum/n
    mrr = mrr_sum/n
    return (recall, recall_k), (mrr, mrr_k)
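
As a side note, the per-row recall and MRR bookkeeping in the loop above can be checked in isolation. The following is a minimal NumPy sketch, not code from either repository: the function name recall_and_mrr_at_k and its arguments are hypothetical, and it assumes pred holds one score per item and labels holds integer item indices rather than one-hot vectors.

```python
import numpy as np


def recall_and_mrr_at_k(pred, labels, k=10):
    """Return (recall@k, MRR@k) for one batch (hypothetical helper).

    pred:   (batch, n_items) array of scores, higher is better
    labels: (batch,) array of true item indices
    """
    pred = np.asarray(pred)
    labels = np.asarray(labels)

    # Indices of the k highest-scoring items in each row, best first.
    top_k = np.argsort(pred, axis=1)[:, -k:][:, ::-1]

    # hits[i, j] is True when the j-th ranked item of row i is the target.
    hits = top_k == labels[:, None]
    hit_any = hits.any(axis=1)

    # 1-based rank of the target; only meaningful where hit_any is True.
    ranks = hits.argmax(axis=1) + 1

    recall = hit_any.mean()
    mrr = np.where(hit_any, 1.0 / ranks, 0.0).mean()
    return float(recall), float(mrr)
```

Computing the top-k indices once per batch also avoids sorting each row twice, as the loop does when recall_k and mrr_k are equal.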

Can someone confirm that this was an error and that my correction works?

eveTu commented 4 years ago

I looked into those steps as well, and I agree with your correction.

hidasib commented 4 years ago

I think that this comment was intended for a different repo (?), because this codebase has no get_metrics method. (Also, the code above calls methods that do not exist in this implementation, e.g. model.layers[1].reset_states: assuming that model is a GRU4Rec object, its layers attribute is a list of integers giving the size of each layer, not a list of the actual GRU layers, so its elements have no reset_states method. Many other things don't align with this repo either.)

The metrics (recall and MRR) can be measured with either evaluate_sessions_batch or evaluate_gpu from evaluation.py. The latter (evaluate_gpu) is recommended because it has higher GPU utilization; there, the hidden states are reset to zero in lines 184-189. If you still use evaluate_sessions_batch, the hidden state is reset to zero inside the predict_next_batch method in gru4rec.py whenever any of the session indexes change (see lines 657-662).
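
To illustrate the idea of resetting the hidden state whenever a session index changes, here is a minimal NumPy sketch. It is not the Theano code from gru4rec.py or evaluation.py: the function name reset_changed_sessions and its arguments are hypothetical, and it assumes one hidden-state row per batch slot and one session id per slot.

```python
import numpy as np


def reset_changed_sessions(hidden, prev_session_ids, session_ids):
    """Zero the hidden-state rows whose session id changed (hypothetical helper).

    hidden:           (batch, n_hidden) array of hidden states
    prev_session_ids: (batch,) session id held by each slot at the previous step
    session_ids:      (batch,) session id held by each slot at the current step
    """
    prev_session_ids = np.asarray(prev_session_ids)
    session_ids = np.asarray(session_ids)

    # True for slots that now hold a different session than before.
    changed = prev_session_ids != session_ids

    # Work on a copy so the caller's array is left untouched.
    hidden = np.array(hidden, copy=True)
    hidden[changed] = 0.0
    return hidden
```

Slots whose session id is unchanged keep their state, so only sessions that just started begin from a zero hidden state.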

(There is a third evaluation method in evaluation.py, but that one can't be used to evaluate the GRU4Rec model. It is there for reproducing experiments with the baseline.)

eveTu commented 4 years ago

Oh, sorry about that. The code comes from the GRU4Rec Keras implementation.

JustAnotherDataScientist commented 4 years ago

Very sorry, eveTu is right: the code comes from the GRU4Rec Keras implementation! I am closing the issue right away. Anyway, thanks for sharing your work, Hidasi.