Open jialiangZ opened 1 year ago
class PastNetModel(nn.Module):
    """Spatio-temporal predictor combining a DST branch and an FPG branch.

    The input frames are processed along two paths whose outputs are summed:

    * ``DST_module``: frames are encoded, projected, vector-quantized,
      propagated through ``DynamicPropagation`` and decoded back.
    * ``FPG_module``: applied directly to the raw input frames; its output
      is added element-wise to the decoded prediction.

    Args:
        args: Run configuration. Only ``args.load_pred_train`` and
            ``args.freeze_vqvae`` are read here.
        shape_in: ``(T, C, H, W)`` of one input sequence. Only ``T`` and
            ``C`` are used by the constructor.
        hid_T: Forwarded to ``DynamicPropagation``.
        N_T: Forwarded to ``DynamicPropagation``.
        incep_ker: Forwarded to ``DynamicPropagation``. Kept as a list
            default for interface compatibility; it is never mutated here.
        groups: Forwarded to ``DynamicPropagation``.
        res_units: Forwarded to ``DST``.
        res_layers: Forwarded to ``DST``.
        embedding_nums: Forwarded to ``DST``.
        embedding_dim: Forwarded to ``DST``.
    """

    def __init__(self, args, shape_in, hid_T=256, N_T=8, incep_ker=[3, 5, 7, 11], groups=8, res_units=64, res_layers=2, embedding_nums=512, embedding_dim=64):
        super(PastNetModel, self).__init__()
        T, C, H, W = shape_in
        self.DST_module = DST(in_channel=C,
                              res_units=res_units,
                              res_layers=res_layers,
                              embedding_dim=embedding_dim,
                              embedding_nums=embedding_nums)
        # NOTE(review): the FPG geometry (img_size=64, in_channels=1,
        # input_frames=10) is hard-coded and does not follow `shape_in`;
        # confirm this is intended for inputs of other sizes.
        self.FPG_module = FPG(img_size=64,
                              patch_size=16,
                              in_channels=1,
                              out_channels=1,
                              embed_dim=128,
                              input_frames=10,
                              depth=1,
                              mlp_ratio=2.,
                              uniform_drop=False,
                              drop_rate=0.,
                              drop_path_rate=0.,
                              norm_layer=None,
                              dropcls=0.)

        # The original code referenced `self.vq_vae`, an attribute that is
        # never assigned, so enabling either flag raised AttributeError.
        # `DST_module` is the only submodule owning the encoder / VQ /
        # decoder stack, so it is used as the target here.
        # NOTE(review): presumably "./models/vqvae.ckpt" holds a DST state
        # dict; strict=False silently skips non-matching keys - confirm.
        if args.load_pred_train:
            print_log("Load Pre-trained Model.")
            self.DST_module.load_state_dict(torch.load("./models/vqvae.ckpt"), strict=False)

        if args.freeze_vqvae:
            print_log(f"Params of VQVAE is freezed.")
            for p in self.DST_module.parameters():
                p.requires_grad = False

        # NOTE(review): the factor 64 presumably corresponds to the default
        # `embedding_dim`; confirm before changing `embedding_dim`.
        self.DynamicPro = DynamicPropagation(T*64, hid_T, N_T, incep_ker, groups)

    def forward(self, input_frames):
        """Predict frames from ``input_frames``.

        Args:
            input_frames: Tensor of shape ``(B, T, C, H, W)``.

        Returns:
            Tensor of shape ``(B, T, C, H, W)``: the decoded DST prediction
            plus the FPG output. The vector-quantization loss computed along
            the way is not returned.
        """
        B, T, C, H, W = input_frames.shape

        # FPG branch works on the raw frame sequence.
        pde_features = self.FPG_module(input_frames)

        # Fold time into the batch axis so frames are encoded independently.
        # reshape (rather than view) also accepts non-contiguous inputs.
        input_features = input_frames.reshape([B * T, C, H, W])
        encoder_embed = self.DST_module._encoder(input_features)
        z = self.DST_module._pre_vq_conv(encoder_embed)
        vq_loss, Latent_embed, _, _ = self.DST_module._vq_vae(z)

        # Restore the time axis before propagating over the sequence.
        _, C_, H_, W_ = Latent_embed.shape
        Latent_embed = Latent_embed.reshape(B, T, C_, H_, W_)
        hidden_dim = self.DynamicPro(Latent_embed)

        # Fold time back into the batch axis for per-frame decoding.
        B_, T_, C_, H_, W_ = hidden_dim.shape
        hid = hidden_dim.reshape([B_ * T_, C_, H_, W_])
        predicti_feature = self.DST_module._decoder(hid)

        # Sum the decoded prediction with the FPG branch output.
        predicti_feature = predicti_feature.reshape([B, T, C, H, W]) + pde_features
        return predicti_feature
在 forward 函数中，您直接将 pde_features 与 predicti_feature 相加，请问这里的 pde_features 起到了什么作用呢？
感谢您对我们工作的关注。在实验中，傅里叶模块主要用于弥补低频信息。
在forward函数中,您直接将pde_features与predicti_feature相加,请问这里的pde_features是做了什么角色呢?