ZhengChang467 / STRPM

STRPM: A Spatiotemporal Residual Predictive Model for High-Resolution Video Prediction, CVPR2022
MIT License
17 stars 3 forks source link

关于mask_true的作用以及tau的设置问题 #3

Open chufall opened 2 years ago

chufall commented 2 years ago

您好,感谢您的分享,我对代码有几处疑问(英文怕解释不清)

  1. mask_true的作用 在strpm.py的forward函数中,下面这段代码使用到了mask_true if t < self.configs.input_length: net = frames[:, t] #t=0 1,192,64,64 取某一幅图画 else: time_diff = t - self.configs.input_length #input_length = 4 , 表示预测周期,每4张预测1张 net = mask_true[:, time_diff] frames[:, t] + (1 - mask_true[:, time_diff]) x_gen 调试后发现train时,mask_true都是1,test时mask_true都是0: 问题1:是不是可以理解为,训练时,都输入原始图片,而测试时,输入Input_length长度的图片以后,只使用预测出的图片来循环预测? 问题2:但既然是那样的话,在训练代码的loss计算时,为何用了所有的9张图输出,不否应该把t=0,1,2,3时的输出剔除掉来计算loss更好?

  2. tau的设置 tau的设置是不是与input_length有关? 我的场景是用8张图来预测后面2张图,input_length=8,total_length=10, tau是不是应该设为9为好? 这块没有特别理解.

谢谢!盼复! Qc 2020.6.20

ZhengChang467 commented 2 years ago

你好,感谢你对我们工作的关注!关于你的问题,我们做出如下的回复:

  1. STRPM_run.py schedule_sampling函数会根据当前的训练进度动态更新mask_true的值,训练初期mask_true全为1,训练后期Mask_true接近于0,因为刚开始的预测值质量低,直接当做下一个时刻的输入会使网络难以训练,并非训练时,都输入原始图片
  2. 网络之中预测单元会在时域上进行复制,即每一个时刻的预测单元的参数实际上是一致的,因此每一个时刻都需要进行约束,而非仅仅预测时刻需要约束
  3. 严格来说,tau和input_length并没有直接关系,但是tau越大,性能会越高,根据你的场景Tau设为8应该是性能比较好的,但是计算量也会比较大一些,可以考虑将性能和tau做一个均衡处理。 希望对你有所帮助!