GestaltCogTeam / STEP

Code for our SIGKDD'22 paper Pre-training-Enhanced Spatial-Temporal Graph Neural Network For Multivariate Time Series Forecasting.

Tsformer pretrain question #18

Closed. Jimmy-7664 closed this issue 1 year ago.

Jimmy-7664 commented 1 year ago

In the file tsformer_runner.py, the "test" method is used to evaluate the model during training; however, I found that it only uses the last batch of the test_data_loader, because the re-scaling and metric code sits outside the for loop. Is there something wrong with the loop? Am I wrong?

```python
@torch.no_grad()
@master_only
def test(self):
    """Evaluate the model.

    Args:
        train_epoch (int, optional): current epoch if in training process.
    """

    for _, data in enumerate(self.test_data_loader):
        forward_return = self.forward(data=data, epoch=None, iter_num=None, train=False)
    # re-scale data
    prediction_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[0], **self.scaler["args"])
    real_value_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[1], **self.scaler["args"])
    # metrics
    for metric_name, metric_func in self.metrics.items():
        metric_item = metric_func(prediction_rescaled, real_value_rescaled, null_val=self.null_val)
        self.update_epoch_meter("test_" + metric_name, metric_item.item())
```
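To illustrate what I mean, here is a toy sketch of the same loop pattern (the loader and forward function below are dummies, not the real ones from the repo): after the loop finishes, `forward_return` only holds the output of the last batch, so the re-scaling and metrics after the loop only ever see that batch.

```python
import torch

# dummy stand-ins for self.test_data_loader and self.forward
test_data_loader = [torch.full((2, 4), float(i)) for i in range(3)]  # three "batches": 0.0, 1.0, 2.0

def forward(data):
    return data, data  # pretend (prediction, real_value)

for _, data in enumerate(test_data_loader):
    forward_return = forward(data)

# forward_return now only references the LAST batch
print(forward_return[0].unique())  # tensor([2.]) -> metrics would be computed on batch 2 only
```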
zezhishao commented 1 year ago

Thanks for your question. I have checked, and this is indeed a bug... But it doesn't affect the final results, because this metric is not used for anything. Thanks again for reporting it; I'll fix it right away!

Jimmy-7664 commented 1 year ago

Thanks for your reply. So the correct test method is the one in the file "base_tsf_runner.py":

```python
def test(self):
    """Evaluate the model.

    Args:
        train_epoch (int, optional): current epoch if in training process.
    """

    # test loop
    prediction = []
    real_value = []
    for _, data in enumerate(self.test_data_loader):
        forward_return = self.forward(data, epoch=None, iter_num=None, train=False)
        prediction.append(forward_return[0])        # preds = forward_return[0]
        real_value.append(forward_return[1])        # testy = forward_return[1]
    prediction = torch.cat(prediction, dim=0)
    real_value = torch.cat(real_value, dim=0)
    # re-scale data
    prediction = SCALER_REGISTRY.get(self.scaler["func"])(
        prediction, **self.scaler["args"])
    real_value = SCALER_REGISTRY.get(self.scaler["func"])(
        real_value, **self.scaler["args"])
    # summarize the results
    # test performance at each horizon
    for i in self.evaluation_horizons:
        # For horizon i, only calculate the metrics **at that time** slice here.
        pred = prediction[:, i, :]
        real = real_value[:, i, :]
        # metrics
        metric_results = {}
        for metric_name, metric_func in self.metrics.items():
            metric_item = self.metric_forward(metric_func, [pred, real])
            metric_results[metric_name] = metric_item.item()
        log = "Evaluate best model on test data for horizon " + \
            "{:d}, Test MAE: {:.4f}, Test RMSE: {:.4f}, Test MAPE: {:.4f}"
        log = log.format(
            i + 1, metric_results["MAE"], metric_results["RMSE"], metric_results["MAPE"])
        self.logger.info(log)
    # overall test performance
    for metric_name, metric_func in self.metrics.items():
        metric_item = self.metric_forward(metric_func, [prediction, real_value])
        self.update_epoch_meter("test_" + metric_name, metric_item.item())
        metric_results[metric_name] = metric_item.item()
```

Am I right?
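For my own understanding, here is a toy sketch of what that per-horizon loop computes, assuming the collected tensors are shaped [num_samples, horizon, num_nodes] (the real tensors in the repo may carry an extra feature dimension, and the real metrics apply null-value masking; plain MAE is just a stand-in here):

```python
import torch

def mae(pred: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
    # plain mean absolute error, without the null_val masking used in the repo
    return torch.mean(torch.abs(pred - real))

# toy rescaled tensors: [num_samples, horizon, num_nodes]
prediction = torch.randn(64, 12, 207)
real_value = torch.randn(64, 12, 207)

# metric at each individual horizon i
for i in range(prediction.shape[1]):
    print(f"horizon {i + 1}: MAE {mae(prediction[:, i, :], real_value[:, i, :]).item():.4f}")

# overall metric across all horizons at once
print(f"overall: MAE {mae(prediction, real_value).item():.4f}")
```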

zezhishao commented 1 year ago

No, the test function in base_tsf_runner.py is designed for the Time Series Forecasting (TSF) task, which is not compatible with the reconstruction task in the pre-training stage. Actually, I think you can fix this by adding one level of indentation (a Tab) to lines 84~90 of tsformer_runner.py, so that the re-scaling and metric updates run inside the loop, like:

    @torch.no_grad()
    @master_only
    def test(self):
        """Evaluate the model.

        Args:
            train_epoch (int, optional): current epoch if in training process.
        """

        for _, data in enumerate(self.test_data_loader):
            forward_return = self.forward(data=data, epoch=None, iter_num=None, train=False)
            # re-scale data
            prediction_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[0], **self.scaler["args"])
            real_value_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[1], **self.scaler["args"])
            # metrics
            for metric_name, metric_func in self.metrics.items():
                metric_item = metric_func(prediction_rescaled, real_value_rescaled, null_val=self.null_val)
                self.update_epoch_meter("test_"+metric_name, metric_item.item())
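With those lines indented into the loop, the epoch meter gets one update per test batch, so the logged "test_*" values presumably become an average over all test batches rather than over the last one only (assuming update_epoch_meter keeps a running average, as meter-style helpers typically do). A minimal sketch of that accumulation, using a stand-in meter rather than the actual implementation behind update_epoch_meter:

```python
class RunningAvgMeter:
    """Stand-in for the epoch meter: tracks the running average of per-batch values."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value: float):
        self.total += value
        self.count += 1

    @property
    def avg(self) -> float:
        return self.total / max(self.count, 1)


meter = RunningAvgMeter()
for batch_mae in [0.31, 0.28, 0.35]:  # e.g. per-batch MAE values produced inside the loop
    meter.update(batch_mae)
print(f"test_MAE for the epoch: {meter.avg:.4f}")  # 0.3133, averaged over all test batches
```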
Jimmy-7664 commented 1 year ago

Thanks for your answer, I got the idea. :)

zezhishao commented 1 year ago

This bug is now fixed. Thanks again for your report!