wtysos11 / blogWiki

Used to store and organize public papers.
Notes on a scoping error in a nested Python function #162

Open wtysos11 opened 3 years ago

wtysos11 commented 3 years ago

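Error message: local variable 'batch_x_mark' referenced before assignment

Problem description: inside the function composeTest I implemented another function, getResult, which is defined inside a loop over a dataloader. My idea was that on every iteration the function could access batch_x_mark and the other loop values directly, so I would not have to pass many parameters. In practice, it cannot access them.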

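I tested the scoping rules on my own and confirmed that a nested function can normally access variables from the enclosing function. In my case, however, I can only access the parameters and the outermost self. Something seems off.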
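
A minimal sketch of the behaviour described above, independent of the code in this issue (the function names `read_only` and `assign_same_name` are hypothetical). Reading an outer variable from a nested function works, but assigning to the same name anywhere in the nested function makes Python treat that name as local for the whole function body, so a read that happens before the assignment fails:

```python
def read_only():
    value = 1

    def inner():
        # `value` is only read here, so it resolves to the enclosing scope.
        return value + 1

    return inner()


def assign_same_name():
    value = 1

    def inner():
        # The assignment makes `value` local to `inner`, so the right-hand
        # side reads a local name that has not been bound yet.
        value = value + 1
        return value

    return inner()


print(read_only())

try:
    assign_same_name()
except UnboundLocalError as exc:
    # The exact wording of the message depends on the Python version.
    print(type(exc).__name__, exc)
```

This would also explain why the parameters and self remain accessible: the nested function never assigns to those names.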

Original code (test code adapted from Informer)

    def composeTest(self,low_model,high_model, setting):
        '''
            The test result is the sum of the two models' outputs
        '''
        test_data, test_loader = self._get_data(flag='test')

        low_model.eval()
        high_model.eval()

        preds = []
        trues = []

        for i, (batch_x,batch_y,batch_x_mark,batch_y_mark,batch_pivot,batch_min,batch_max) in enumerate(test_loader):
            if self.args.scale_type == "batch":
                batch_x = miniBatchNormalization(batch_x)
                batch_y = miniBatchNormalization(batch_y)

            def getResult(model,x,y):
                x = x.float().to(self.device)
                y = y.float()
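                # NOTE: assigning to batch_x_mark / batch_y_mark here makes both names local
                # to getResult, so the right-hand side reads an unbound local (the error above)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)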

                # decoder input
                dec_inp = torch.zeros_like(y[:,-self.args.pred_len:,:]).float()
                dec_inp = torch.cat([y[:,:self.args.label_len,:], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                if self.args.output_attention:
                    outputs = model(x, batch_x_mark, dec_inp, batch_y_mark)[0]
                else:
                    outputs = model(x, batch_x_mark, dec_inp, batch_y_mark)
                # ANCHOR token modification
                y = y[:,-self.args.pred_len:,:].to(self.device)
                # ! Modified part (reverted)
                # batch_y = batch_y.to(self.device)

                pred = outputs.detach().cpu().numpy()#.squeeze()
                true = y.detach().cpu().numpy()#.squeeze()
                return pred,true
            # TODO validation procedure and results
            lowFreqPred,lowFreqTrue = getResult(low_model,batch_x[:,:,0],batch_y[:,:,0])
            highFreqPred,highFreqTrue = getResult(high_model,batch_x[:,:,1],batch_y[:,:,1])
            pred = lowFreqPred + highFreqPred
            true = lowFreqTrue + highFreqTrue
            # ! Modified part (reverted)
            # pred = outputs[:,-self.args.pred_len:,:].detach().cpu().numpy()#.squeeze()
            # true = batch_y[:,-self.args.pred_len:,:].detach().cpu().numpy()#.squeeze()
            # ANCHOR restore the data to its original scale
            preds.append(reverseData(pred,batch_pivot,batch_min,batch_max))
            trues.append(reverseData(true,batch_pivot,batch_min,batch_max))

        finishDataReadingByLastCacheName() # NOTE EEMD cache set True
        preds = np.array(preds)
        trues = np.array(trues)
        print('test shape:', preds.shape, trues.shape)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        print('test shape:', preds.shape, trues.shape)

        # result save
        # ANCHOR Tight-coupling warning: this part is linked with automated.py
        folder_path = './results/' + setting +'/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        mae, mse, rmse, mape, mspe,nrmse,smape,cover,coverTotal = metric(preds, trues)
        print('mse:{}, rmse:{}, mape:{}, nrmse:{}, smape:{}, cover:{},coverTotal:{}'.format(mse,rmse, mape,nrmse,smape,cover,coverTotal))

        # np.save(folder_path+'metrics.npy', np.array([mae, mse, rmse, mape, mspe,nrmse,smape]))
        np.save(folder_path+'metrics.npy', np.array([mse, rmse, mape,nrmse,smape,cover,coverTotal]))
        np.save(folder_path+'pred.npy', preds)
        np.save(folder_path+'true.npy', trues)

        return
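
One possible way around the error is to bind the converted values to new names inside the nested function, so the loop variables are only read and never reassigned. The sketch below is a simplified stand-in, not the Informer code: `compose`, `get_result`, `scale` and the plain-number loader are hypothetical, chosen so the example runs without torch.

```python
def compose(loader):
    results = []
    for x, x_mark in loader:

        def get_result(scale, x):
            # Bind the converted value to a new name instead of reassigning
            # `x_mark`. `x_mark` is then only read, so it resolves to the
            # loop variable of the enclosing function.
            mark = float(x_mark)
            return scale * x + mark

        # Two calls per iteration, mirroring the low/high model calls above.
        results.append(get_result(1, x) + get_result(10, x))
    return results


print(compose([(1, 0.5), (2, 1.5)]))  # [12.0, 25.0]
```

Applied to getResult, this would mean writing something like `x_mark = batch_x_mark.float().to(self.device)` and passing `x_mark` to the model. Passing batch_x_mark and batch_y_mark as explicit parameters, or declaring them `nonlocal` at the top of getResult, should work as well.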