thuml / Time-Series-Library

A Library for Advanced Deep Time Series Models.
MIT License
6.38k stars 1.02k forks source link

关于long-term-forecasting训练中的val loss、test loss可能有误的问题 #537

Open shiwt03 opened 2 days ago

shiwt03 commented 2 days ago

exp/exp_long_term_forecasting.py 中vali函数是这样写的:

    def vali(self, vali_data, vali_loader, criterion):
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):

                # some inference code

                pred = outputs.detach().cpu()
                true = batch_y.detach().cpu()

                loss = criterion(pred, true)

                total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

同时dataloader的drop_last均被设置为False,这样是否会由于取了“平均的平均”造成val loss、test loss略低于实际的情况? 我在实验中确实观察到了训练中的test loss与最终test loss存在一个微小但不可忽略的gap的现象。

虽然这并不影响两个epoch之间val loss的相对值和early stop的判断,训练中也不应该参考test loss,但改正这个问题有利于在训练过程中与其他baseline进行比较以提高研究效率。

我的修改如下,不知道这是否合理?

    def vali(self, vali_data, vali_loader, criterion):
        self.model.eval()
        loss_sum = 0.0
        samples = 0
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):

                # some inference code

                pred = outputs.detach().cpu()
                true = batch_y.detach().cpu()

                loss = criterion(pred, true)
                loss_sum += loss * pred.size(0)
                samples += pred.size(0)

        total_loss = loss_sum / samples
        self.model.train()
        return total_loss

似乎在classification任务中也有类似问题,希望能得到解答,谢谢!

Musongwhk commented 2 days ago

你观察得非常细心,十分感谢!以下是我个人的看法。 其一,你所说的取“平均的平均”得到的计算结果其实并不一定比平均值低,也有可能比平均值高(我们可以拿1至10这十个数字为例验证); 其二,val loss的主要作用是帮助我们选择最佳的模型架构、超参数或正则化方法,也可以用来判断是否需要早停以防止过拟合。正如你所说,val loss的计算其实并不会影响我们对于模型训练效果的相对判断和早停功能; 其三,不同baseline之间的比较应当关注测试集上的效果test loss。我们的test方法中计算损失就是按照你修改的思路实现的,求的是所有样本批次损失的平均值(包括均方误差MSE、平均绝对误差MAE等),而非“平均的平均”。 再次感谢你细心的观察和分析,有任何问题欢迎继续讨论!