easylearningscores / PastNet

MM'2024
4 stars 0 forks source link

请问傅里叶模块是如何发挥作用的？ #2

Open jialiangZ opened 9 months ago

jialiangZ commented 9 months ago
class PastNetModel(nn.Module):
    """PastNet video-prediction model.

    Two branches are combined in ``forward``:

    * ``DST_module`` - a VQ-VAE-style branch (``_encoder`` -> ``_pre_vq_conv``
      -> ``_vq_vae`` quantizer -> ``_decoder``). The quantized latents of all
      frames are passed through ``DynamicPro`` before decoding.
    * ``FPG_module`` - the Fourier branch; its output is added as a residual
      to the decoded prediction.
    """

    def __init__(self,
                 args,
                 shape_in,
                 hid_T=256,
                 N_T=8,
                 incep_ker=(3, 5, 7, 11),
                 groups=8,
                 res_units=64,
                 res_layers=2,
                 embedding_nums=512,
                 embedding_dim=64):
        """Build the DST, FPG and DynamicPropagation sub-modules.

        Args:
            args: config object; reads ``args.load_pred_train`` (load the
                pre-trained VQ-VAE weights from ``./models/vqvae.ckpt``) and
                ``args.freeze_vqvae`` (disable gradients for ``DST_module``).
            shape_in: ``(T, C, H, W)`` of one input clip.
            hid_T, N_T, incep_ker, groups: forwarded to ``DynamicPropagation``.
            res_units, res_layers, embedding_nums, embedding_dim: forwarded
                to ``DST``.
        """
        super().__init__()
        T, C, H, W = shape_in
        self.DST_module = DST(in_channel=C,
                              res_units=res_units,
                              res_layers=res_layers,
                              embedding_dim=embedding_dim,
                              embedding_nums=embedding_nums)

        # Frame count, channels and spatial size come from shape_in rather than
        # being hard-coded to 10 x 1 x 64 x 64: the FPG output is added
        # element-wise to the (B, T, C, H, W) prediction in forward, so it must
        # match the input clip. NOTE(review): assumes square frames (H == W).
        self.FPG_module = FPG(img_size=H,
                              patch_size=16,
                              in_channels=C,
                              out_channels=C,
                              embed_dim=128,
                              input_frames=T,
                              depth=1,
                              mlp_ratio=2.,
                              uniform_drop=False,
                              drop_rate=0.,
                              drop_path_rate=0.,
                              norm_layer=None,
                              dropcls=0.)

        if args.load_pred_train:
            print_log("Load Pre-trained Model.")
            # The VQ-VAE is DST_module (there is no self.vq_vae attribute).
            # map_location lets a GPU-saved checkpoint load on CPU-only hosts.
            state_dict = torch.load("./models/vqvae.ckpt", map_location="cpu")
            incompatible = self.DST_module.load_state_dict(state_dict, strict=False)
            # strict=False skips mismatched keys silently; report them so a
            # checkpoint that loads nothing does not go unnoticed.
            if incompatible.missing_keys or incompatible.unexpected_keys:
                print_log(f"VQVAE checkpoint mismatch: "
                          f"missing={incompatible.missing_keys}, "
                          f"unexpected={incompatible.unexpected_keys}")

        if args.freeze_vqvae:
            print_log("Params of VQVAE is freezed.")
            for p in self.DST_module.parameters():
                p.requires_grad = False

        # Input channels are T * embedding_dim (previously hard-coded as T*64,
        # where 64 is the default embedding_dim). Presumably DynamicPropagation
        # folds the T latent frames into channels - TODO confirm.
        # list(...) hands over a fresh list so the default is never shared.
        self.DynamicPro = DynamicPropagation(T * embedding_dim, hid_T, N_T,
                                             list(incep_ker), groups)

    def forward(self, input_frames):
        """Predict frames from ``input_frames`` of shape ``(B, T, C, H, W)``.

        Returns a tensor of shape ``(B, T, C, H, W)``: the DST-decoded
        prediction plus the FPG residual.

        NOTE(review): the VQ loss from ``_vq_vae`` is discarded here; if the
        VQ-VAE is trained jointly, it never reaches the objective - confirm.
        """
        B, T, C, H, W = input_frames.shape
        pde_features = self.FPG_module(input_frames)

        # reshape (not view) so non-contiguous inputs, e.g. after permute, work.
        input_features = input_frames.reshape(B * T, C, H, W)
        encoder_embed = self.DST_module._encoder(input_features)
        z = self.DST_module._pre_vq_conv(encoder_embed)
        _vq_loss, latent_embed, _, _ = self.DST_module._vq_vae(z)

        # Restore the time axis so DynamicPro sees (B, T, C', H', W').
        _, C_, H_, W_ = latent_embed.shape
        latent_embed = latent_embed.reshape(B, T, C_, H_, W_)

        hidden = self.DynamicPro(latent_embed)
        B_, T_, C_, H_, W_ = hidden.shape
        hid = hidden.reshape(B_ * T_, C_, H_, W_)

        prediction = self.DST_module._decoder(hid)
        return prediction.reshape(B, T, C, H, W) + pde_features

在 forward 函数中，您直接将 pde_features 与 predicti_feature 相加。请问这里的 pde_features 起什么作用？

easylearningscores commented 1 month ago

感谢您对我们工作的关注。在实验中，傅里叶模块主要用于弥补低频信息。