zhouhaoyi / Informer2020

The GitHub repository for the paper "Informer" accepted by AAAI 2021.
Apache License 2.0

The decoders of Informer and Autoformer do not play the role they should #564

Open CXL-edu opened 11 months ago

CXL-edu commented 11 months ago

For Transformer-based models, the decoder can be trained with teacher forcing to speed up training, and in later epochs the ratio of ground-truth to predicted values can even be scheduled. But during evaluation/inference, the decoder input should be updated iteratively from the output of the previous step and fed back in, so that decoding is autoregressive.

dec_in  : [batch, pred_len, d_model]  # for simplicity, assume batch and d_model are both 1
dec_out : [batch, pred_len, d_model]

dec_in = [[[0], [0] ... [0]]]                          # initial decoder input, used to predict the first step; model output is dec_out_1
...
dec_in = [[dec_out_1[0, 0], [0] ... [0]]]              # input after the first update, used to predict the second step; model output is dec_out_2
...
dec_in = [[dec_out_1[0, 0], dec_out_2[0, 1] ... [0]]]  # keep iterating until the prediction length is reached
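
A minimal sketch of the autoregressive inference loop described above; `model`, `enc_inp` and the single-call signature are placeholders, not the repo's actual inference API:

```python
import torch

# Sketch of iterative (autoregressive) decoding at inference time.
# `model` and `enc_inp` are hypothetical placeholders.
def autoregressive_decode(model, enc_inp, pred_len, d_model):
    batch = enc_inp.shape[0]
    dec_inp = torch.zeros(batch, pred_len, d_model)   # all-zero initial decoder input
    with torch.no_grad():
        for t in range(pred_len):
            dec_out = model(enc_inp, dec_inp)         # [batch, pred_len, d_model]
            dec_inp[:, t, :] = dec_out[:, t, :]       # feed the step-t prediction back in
    return dec_inp
```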
efg001 commented 1 month ago

I had the same idea. It's a simple change: do you want to give it a try? I am seeing the model overfit the training dataset much earlier after teacher forcing is introduced. It's pretty useless, at least for the dataset I am running with.

I can't find any paper applying teacher forcing to time series data.

If I had to guess, they probably left teacher forcing out intentionally. Intuitively, teacher forcing works well on tasks that humans excel at because it promotes memorization: you are feeding the model the answer and letting it remember it. For generative models this is good because you are telling the model what kind of sentence you expect, and the same goes for image recognition and translation, but for time series forecasting I think you do not want the model to memorize the answer.

Just to be crystal clear. Here is their setup:

  Seq    pred
[------][---]
  label
   [---]
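
For reference, as I understand it the repo builds the decoder input by concatenating the label window of ground truth with zero placeholders for the prediction window. A rough paraphrase (tensor names and exact lines are assumptions, not a verbatim copy):

```python
import torch

# Rough paraphrase of the decoder-input assembly in the experiment loop:
# label window of ground truth + zeros for the prediction window (assumed names).
label_len, pred_len = 48, 24
batch_y = torch.randn(32, label_len + pred_len, 7)                 # [batch, label+pred, features]
dec_inp = torch.zeros(batch_y.shape[0], pred_len, batch_y.shape[-1])
dec_inp = torch.cat([batch_y[:, :label_len, :], dec_inp], dim=1)   # [batch, label+pred, features]
```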

To introduce teacher forcing you need the "shift by one" approach from the original Transformer paper, so I changed it to this: instead of having pred come after Seq, you shift Seq right by one (or more) steps to get Pred, and the overlapping part is what allows teacher forcing (see the sketch after the diagram).

   Seq  
[------]
   label  
   [----]
     pred 
     [------]
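
A sketch of the shift-by-one pairing under this layout; `y` and the shapes are hypothetical, not code from this repo:

```python
import torch

# Shift-by-one teacher forcing on a series `y` (hypothetical names):
# the decoder sees y[t-1] and is trained to emit y[t].
y = torch.randn(32, 96, 7)     # [batch, time, features]
dec_inp = y[:, :-1, :]         # steps 0 .. T-2 fed to the decoder
target  = y[:, 1:, :]          # steps 1 .. T-1 as supervised targets
# With a causal attention mask, position t of dec_inp only attends to steps <= t,
# so each prediction of y[t] is conditioned on ground-truth history (teacher forcing).
```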

Note that in the original paper, Seq is in a different language (the source language) and the decoder is allowed to attend to every token in Seq. Here we are dealing with time series forecasting, so if you do the above, you are leaking the answer through the encoder and there is no trivial way to get around it.

I removed the encoder and made a decoder-only model out of it.
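
In a decoder-only setup the leakage is avoided with a causal mask. A minimal sketch, assuming standard upper-triangular masking as in the original Transformer (not necessarily this repo's masking utilities verbatim):

```python
import torch

# Causal mask for a decoder-only model: block attention to future steps.
def causal_mask(seq_len: int) -> torch.Tensor:
    # True marks future positions that must not be attended to.
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

scores = torch.randn(32, 8, 96, 96)                          # [batch, heads, query, key]
scores = scores.masked_fill(causal_mask(96), float("-inf"))  # mask out the future
attn = torch.softmax(scores, dim=-1)
```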

Another approach would be to shift Seq left to avoid the overlap, but without the overlap you can't do teacher forcing.