recommenders-team / recommenders

Best Practices on Recommendation Systems
https://recommenders-team.github.io/recommenders/intro.html
MIT License

Refactor fit method in deeprec base model #539

Open miguelgfierro opened 5 years ago

miguelgfierro commented 5 years ago

Description

Question for @Leavingseason about this method:

import os
import time

import tensorflow as tf  # TF 1.x API (tf.summary.FileWriter)


def fit(self, train_file, valid_file, test_file=None):
    """Fit the model with train_file. Evaluate the model on valid_file per epoch to observe the training status.
    If test_file is not None, evaluate it too.

    Args:
        train_file (str): training data set.
        valid_file (str): validation set.
        test_file (str): test set.

    Returns:
        obj: An instance of self.
    """
    if self.hparams.write_tfevents:
        self.writer = tf.summary.FileWriter(
            self.hparams.SUMMARIES_DIR, self.sess.graph
        )

    train_sess = self.sess
    global_step = 0  # keeps increasing across epochs so TensorBoard steps don't collide
    for epoch in range(1, self.hparams.epochs + 1):
        step = 0  # per-epoch step, used for progress printing
        self.hparams.current_epoch = epoch

        epoch_loss = 0
        train_start = time.time()
        for batch_data_input in self.iterator.load_data_from_file(train_file):
            step_result = self.train(train_sess, batch_data_input)
            (_, step_loss, step_data_loss, summary) = step_result
            if self.hparams.write_tfevents:
                self.writer.add_summary(summary, global_step)
            epoch_loss += step_loss
            step += 1
            global_step += 1
            if step % self.hparams.show_step == 0:
                print(
                    "step {0:d} , total_loss: {1:.4f}, data_loss: {2:.4f}".format(
                        step, step_loss, step_data_loss
                    )
                )

        train_end = time.time()
        train_time = train_end - train_start

        # Save a checkpoint every save_epoch epochs
        if self.hparams.save_model and epoch % self.hparams.save_epoch == 0:
            self.saver.save(
                sess=train_sess,
                save_path=os.path.join(self.hparams.MODEL_DIR, "epoch_" + str(epoch)),
            )

        # Evaluate on the training and validation sets (and the test set, if given)
        eval_start = time.time()
        train_res = self.run_eval(train_file)
        eval_res = self.run_eval(valid_file)
        train_info = ", ".join(
            [
                str(item[0]) + ":" + str(item[1])
                for item in sorted(train_res.items(), key=lambda x: x[0])
            ]
        )
        eval_info = ", ".join(
            [
                str(item[0]) + ":" + str(item[1])
                for item in sorted(eval_res.items(), key=lambda x: x[0])
            ]
        )
        if test_file is not None:
            test_res = self.run_eval(test_file)
            test_info = ", ".join(
                [
                    str(item[0]) + ":" + str(item[1])
                    for item in sorted(test_res.items(), key=lambda x: x[0])
                ]
            )
        eval_end = time.time()
        eval_time = eval_end - eval_start

        # Build one log line per epoch; append test metrics only when a test file was given
        epoch_info = (
            "at epoch {0:d}".format(epoch)
            + " train info: "
            + train_info
            + " eval info: "
            + eval_info
        )
        if test_file is not None:
            epoch_info += " test info: " + test_info
        print(epoch_info)
        print(
            "at epoch {0:d} , train time: {1:.1f} eval time: {2:.1f}".format(
                epoch, train_time, eval_time
            )
        )

    if self.hparams.write_tfevents:
        self.writer.close()

    return self

Would it be OK to remove the test_file argument, along with the code that evaluates on the test file, and make valid_file optional?

The reason is consistency with the rest of the methods.

However, I'd rather not change the base class much if you rely heavily on the test-set evaluation.

FYI @anargyri @yexing99
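A minimal sketch of what the proposed signature could look like: validation becomes optional, and test-set evaluation is left to a separate run_eval call made by the caller. Everything here is hypothetical and not the deeprec API: ToyModel, its single-weight linear model trained with plain SGD (standing in for the real network), and the in-memory lists of (x, y) pairs that stand in for the train/valid/test files.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical stand-in for a data file: a list of (x, y) pairs
Data = List[Tuple[float, float]]


class ToyModel:
    """Hypothetical stand-in for the deeprec base model (one weight, y_hat = w * x)."""

    def __init__(self, epochs: int = 3, lr: float = 0.1):
        self.epochs = epochs
        self.lr = lr
        self.w = 0.0
        self.history: List[Dict[str, object]] = []  # per-epoch metrics

    def _train_step(self, x: float, y: float) -> None:
        # One SGD step on the squared error (w * x - y) ** 2
        grad = 2 * (self.w * x - y) * x
        self.w -= self.lr * grad

    def run_eval(self, data: Data) -> Dict[str, float]:
        # Mean squared error over a data set; works for train, valid or test data
        mse = sum((self.w * x - y) ** 2 for x, y in data) / len(data)
        return {"mse": mse}

    def fit(self, train_data: Data, valid_data: Optional[Data] = None) -> "ToyModel":
        """Train on train_data; if valid_data is given, also evaluate it after each epoch."""
        for epoch in range(1, self.epochs + 1):
            for x, y in train_data:
                self._train_step(x, y)
            record: Dict[str, object] = {"epoch": epoch, "train": self.run_eval(train_data)}
            if valid_data is not None:
                record["valid"] = self.run_eval(valid_data)
            self.history.append(record)
        return self


# Usage: validation is optional, and the test set is evaluated with a separate call
model = ToyModel().fit([(1.0, 2.0), (2.0, 4.0)], valid_data=[(3.0, 6.0)])
test_metrics = model.run_eval([(4.0, 8.0)])
```

Returning self keeps the call chainable, and because run_eval works on any data set, dropping test_file from fit loses no functionality.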

Leavingseason commented 5 years ago

Yes, there are many ways to structure a machine learning training process, and each of them makes sense to some extent.
We can make the code style consistent. Can you point me to the "rest of the methods"? I can take a look at them first. Previously, I mainly followed my own coding conventions.

miguelgfierro commented 5 years ago

One of the patterns we follow is single responsibility, described in https://github.com/Microsoft/Recommenders/wiki/Coding-Guidelines#single-responsibility, which includes an example of splitting the training and test procedures.
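As a hedged illustration of single responsibility applied to this fit method: evaluation and log formatting could live in their own small functions, so fit only trains and the caller decides which splits to evaluate. The names format_metrics, evaluate_splits, epoch_report and fake_eval are hypothetical; the output mirrors the "at epoch ... train info: ... eval info: ..." line printed by the original fit.

```python
from typing import Callable, Dict, Mapping

Metrics = Dict[str, float]


def format_metrics(metrics: Mapping[str, float]) -> str:
    """Render metrics as 'name:value' pairs sorted by name, like the *_info strings in fit()."""
    return ", ".join(f"{name}:{value}" for name, value in sorted(metrics.items()))


def evaluate_splits(
    run_eval: Callable[[str], Metrics], splits: Mapping[str, str]
) -> Dict[str, Metrics]:
    """Evaluate each named data split; only the splits the caller passes are evaluated."""
    return {name: run_eval(path) for name, path in splits.items()}


def epoch_report(epoch: int, results: Mapping[str, Metrics]) -> str:
    """Build the per-epoch log line from already-computed results (no training, no evaluation)."""
    parts = [f"{name} info: {format_metrics(m)}" for name, m in results.items()]
    return f"at epoch {epoch:d} " + " ".join(parts)


def fake_eval(path: str) -> Metrics:
    # Hypothetical evaluator returning fixed metrics per file name
    if path == "valid.txt":
        return {"auc": 0.75, "logloss": 0.5}
    return {"auc": 0.8, "logloss": 0.4}


results = evaluate_splits(fake_eval, {"train": "train.txt", "eval": "valid.txt"})
print(epoch_report(1, results))
# at epoch 1 train info: auc:0.8, logloss:0.4 eval info: auc:0.75, logloss:0.5
```

Adding a test split then means passing one more entry to evaluate_splits rather than adding another branch inside the training loop.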

anargyri commented 5 years ago

An example is https://github.com/Microsoft/Recommenders/blob/master/reco_utils/recommender/sar/sar_singlenode.py, where the fit(), recommend_k_items() and predict() methods are separate, similar to the scikit-learn convention.
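A toy sketch of that convention, assuming in-memory (user, item) interaction pairs. In this sketch, fit() only learns state and returns self, predict() scores given pairs, and recommend_k_items() ranks items the user has not seen. PopularityRecommender and its popularity-count scoring are hypothetical and far simpler than SAR, and the method signatures are illustrative rather than copied from sar_singlenode.py.

```python
from collections import Counter, defaultdict
from typing import Dict, Iterable, List, Set, Tuple


class PopularityRecommender:
    """Hypothetical toy recommender following the fit / predict / recommend_k_items split."""

    def fit(self, interactions: Iterable[Tuple[str, str]]) -> "PopularityRecommender":
        # Learn state only: item popularity and the items each user has already seen
        self.item_counts_: Counter = Counter()
        self.seen_: Dict[str, Set[str]] = defaultdict(set)
        for user, item in interactions:
            self.item_counts_[item] += 1
            self.seen_[user].add(item)
        return self  # returning self allows chaining, as in scikit-learn

    def predict(self, pairs: Iterable[Tuple[str, str]]) -> List[float]:
        # Score given (user, item) pairs; here the score is just item popularity
        return [float(self.item_counts_[item]) for _, item in pairs]

    def recommend_k_items(self, user: str, top_k: int = 10) -> List[str]:
        # Rank items the user has not interacted with, most popular first (ties by name)
        seen = self.seen_.get(user, set())
        candidates = [i for i in self.item_counts_ if i not in seen]
        candidates.sort(key=lambda i: (-self.item_counts_[i], i))
        return candidates[:top_k]


data = [("u1", "a"), ("u2", "a"), ("u2", "b"), ("u3", "c"), ("u3", "a")]
model = PopularityRecommender().fit(data)
print(model.predict([("u1", "a"), ("u1", "b")]))  # [3.0, 1.0]
print(model.recommend_k_items("u1", top_k=2))  # ['b', 'c']
```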