ZhengyaoJiang / PGPortfolio

PGPortfolio: Policy Gradient Portfolio, the source code of "A Deep Reinforcement Learning Framework for the Financial Portfolio Management Problem" (https://arxiv.org/pdf/1706.10059.pdf).
GNU General Public License v3.0

Tradertrainer's _evaluate method does not support set_name "validation" #96

Open ieow opened 6 years ago

ieow commented 6 years ago

In rollingtrainer.py (line 37), the `_evaluate` method of `TraderTrainer` is called with `set_name="validation"`, but `_evaluate` does not support that value.

`_evaluate` only supports the `set_name` values `"test"` and `"training"`. Any other value raises `ValueError`.

In tradertrainer.py, line 74:

```python
def _evaluate(self, set_name, *tensors):
    if set_name == "test":
        feed = self.test_set
    elif set_name == "training":
        feed = self.training_set
    else:
        raise ValueError()
    result = self._agent.evaluate_tensors(feed["X"],feed["y"],last_w=feed["last_w"],
                                          setw=feed["setw"], tensors=tensors)
    return result
```
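One possible direction for a fix is to replace the `if`/`elif` chain with a lookup table, so that a validation feed can be registered when one exists and an unsupported name produces a descriptive error instead of a bare `ValueError()`. The sketch below is hypothetical: `TrainerSketch`, `_select_feed` and the `validation_set` parameter are illustrative names, and the quoted `TraderTrainer` has no validation feed attribute. Only the feed selection is modeled; in the real method the selected feed would then be passed to `self._agent.evaluate_tensors(...)`.

```python
class TrainerSketch:
    """Hypothetical stand-in for TraderTrainer: only the set_name dispatch is modeled."""

    def __init__(self, training_set, test_set, validation_set=None):
        # Map each supported set_name to its feed dict.
        self._feeds = {"training": training_set, "test": test_set}
        # "validation" is registered only when a validation feed is supplied
        # (assumption: the quoted trainer has no such attribute).
        if validation_set is not None:
            self._feeds["validation"] = validation_set

    def _select_feed(self, set_name):
        """Return the feed for set_name, or raise a ValueError naming the valid options."""
        if set_name not in self._feeds:
            raise ValueError(
                "unsupported set_name %r; expected one of %s"
                % (set_name, ", ".join(sorted(self._feeds)))
            )
        return self._feeds[set_name]


# Usage: without a validation feed, "validation" fails with a descriptive message.
trainer = TrainerSketch(training_set={"X": "train"}, test_set={"X": "test"})
try:
    trainer._select_feed("validation")
except ValueError as err:
    print(err)  # unsupported set_name 'validation'; expected one of test, training
```

With this shape, the caller in rollingtrainer.py would only be able to request `"validation"` after a validation feed has actually been built and passed in.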

In rollingtrainer.py, line 32:

```python
def __rolling_logging(self):
    fast_train = self.train_config["fast_train"]
    if not fast_train:
        tflearn.is_training(False, self._agent.session)

        v_pv, v_log_mean = self._evaluate("validation",
                                          self._agent.portfolio_value,
                                          self._agent.log_mean)
        t_pv, t_log_mean = self._evaluate("test", self._agent.portfolio_value, self._agent.log_mean)
        loss_value = self._evaluate("training", self._agent.loss)

        logging.info('training loss is %s\n' % loss_value)
        logging.info('the portfolio value on validation asset is %s\nlog_mean is %s\n' %
                     (v_pv,v_log_mean))
        logging.info('the portfolio value on test asset is %s\n mean is %s' % (t_pv,t_log_mean))
```
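Because the `"validation"` call comes first in `__rolling_logging`, the `ValueError` propagates out of the method before the test and training sets are evaluated, so none of the three `logging.info` calls runs whenever `fast_train` is false. The stub below reproduces just that control flow without TensorFlow or tflearn. `StubTrainer` and `rolling_logging` are hypothetical names (a single underscore is used to avoid Python's name mangling), and the stub's `_evaluate` returns the set name instead of the tensor values the real method returns.

```python
import logging


class StubTrainer:
    """Hypothetical stub that copies the set_name branches of the quoted _evaluate."""

    def __init__(self):
        self.test_set = {"name": "test"}
        self.training_set = {"name": "training"}
        self.evaluated = []  # records which sets were evaluated, in call order

    def _evaluate(self, set_name, *tensors):
        # Same branching as tradertrainer.py: only "test" and "training" are accepted.
        if set_name == "test":
            feed = self.test_set
        elif set_name == "training":
            feed = self.training_set
        else:
            raise ValueError()
        self.evaluated.append(feed["name"])
        return feed["name"]

    def rolling_logging(self, fast_train=False):
        # Same call order as rollingtrainer.py: validation, then test, then training.
        if not fast_train:
            v = self._evaluate("validation")
            t = self._evaluate("test")
            loss = self._evaluate("training")
            logging.info("validation=%s test=%s loss=%s", v, t, loss)


trainer = StubTrainer()
try:
    trainer.rolling_logging(fast_train=False)
except ValueError:
    print("ValueError raised; evaluated so far:", trainer.evaluated)  # evaluated so far: []
```

With `fast_train=True` the stub skips every `_evaluate` call, which is consistent with the error only appearing when fast training is turned off.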