easylearningscores / EarthFarseer

AAAI'2024

Hello, a question about the stacking details of TeDev #6

Open li0759 opened 3 months ago

li0759 commented 3 months ago

Hello authors, I noticed that the TeDev module in your model diagram is stacked multiple times. Does this mean the three sub-modules (MFSC, the Fourier module, and Norm) are stacked together as a unit for n layers, or is there only a single MFSC layer and a single Norm layer, with only the Fourier module in the middle stacked n times?

easylearningscores commented 3 months ago

Thank you for your interest in our work. Referring to the TeDev code below: the number of MFSC layers is controlled by the hyperparameter N_T, while the number of Fourier-module layers is controlled by "for i in range(12)", where 12 means 12 layers; you can change it to suit your needs.

# Note: Inception and FourierNetBlock are defined elsewhere in the repository's model code.
import torch
import torch.nn as nn
from functools import partial


class TeDev(nn.Module):
    def __init__(self, channel_in, channel_hid, N_T, h, w, incep_ker=[3, 5, 7, 11], groups=8):
        super(TeDev, self).__init__()

        self.N_T = N_T
        # Encoder: N_T stacked Inception (MFSC) layers.
        enc_layers = [Inception(channel_in, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups)]
        for i in range(1, N_T - 1):
            enc_layers.append(Inception(channel_hid, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups))
        enc_layers.append(Inception(channel_hid, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups))

        # Decoder: N_T stacked Inception (MFSC) layers, fed with skip connections.
        dec_layers = [Inception(channel_hid, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups)]
        for i in range(1, N_T - 1):
            dec_layers.append(
                Inception(2 * channel_hid, channel_hid // 2, channel_hid, incep_ker=incep_ker, groups=groups))
        dec_layers.append(Inception(2 * channel_hid, channel_hid // 2, channel_in, incep_ker=incep_ker, groups=groups))
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm = norm_layer(channel_hid)

        self.enc = nn.Sequential(*enc_layers)
        # Stochastic-depth rates for the Fourier blocks (all zero here).
        dpr = [x.item() for x in torch.linspace(0, 0, 12)]
        self.h = h
        self.w = w
        # Fourier (spectral) blocks: 12 layers; change range(12) to stack a
        # different number of blocks.
        self.blocks = nn.ModuleList([FourierNetBlock(
            dim=channel_hid,
            mlp_ratio=4,
            drop=0.,
            drop_path=dpr[i],
            act_layer=nn.GELU,
            norm_layer=norm_layer,
            h=self.h,
            w=self.w)
            for i in range(12)
        ])
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        B, T, C, H, W = x.shape
        bias = x
        x = x.reshape(B, T * C, H, W)

        # downsampling
        skips = []
        z = x
        for i in range(self.N_T):
            z = self.enc[i](z)
            if i < self.N_T - 1:
                skips.append(z)

        # Spectral Domain
        B, D, H, W = z.shape
        N = H * W
        z = z.permute(0, 2, 3, 1)
        z = z.reshape(B, N, D)  # reshape (not view): the permuted tensor is non-contiguous
        for blk in self.blocks:
            z = blk(z)
        z = self.norm(z).permute(0, 2, 1)

        z = z.reshape(B, D, H, W)

        # upsampling
        z = self.dec[0](z)
        for i in range(1, self.N_T):
            z = self.dec[i](torch.cat([z, skips[-i]], dim=1))

        y = z.reshape(B, T, C, H, W)
        return y + bias
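
For illustration only, a minimal sketch (not code from the repository) of how the hard-coded range(12) could be exposed as a constructor argument, so the Fourier depth can be changed without editing the class; the name num_fourier_blocks and the subclass TeDevConfigurable are hypothetical:

# Hypothetical variant: expose the Fourier-block depth as an argument.
# TeDev, FourierNetBlock, nn and partial are assumed to be in scope as above.
class TeDevConfigurable(TeDev):
    def __init__(self, channel_in, channel_hid, N_T, h, w,
                 num_fourier_blocks=12, **kwargs):
        super().__init__(channel_in, channel_hid, N_T, h, w, **kwargs)
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        # Rebuild self.blocks with the requested depth instead of the fixed 12.
        self.blocks = nn.ModuleList([
            FourierNetBlock(dim=channel_hid, mlp_ratio=4, drop=0.,
                            drop_path=0., act_layer=nn.GELU,
                            norm_layer=norm_layer, h=h, w=w)
            for _ in range(num_fourier_blocks)
        ])

The MFSC depth is still governed by N_T, matching the reply above.
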
li0759 commented 3 months ago

Thank you for your answer. I noticed the description of the scalability experiment in the paper: “Scalability analysis In our implementation, we can quickly expand the model size by stacking blocks layers. As shown in Tab 2, we further explore the scalability of the temporal and spatial module. As meteorological exhibit highly nonlinear and chaotic characteristics, we selected the SEVIR to analyze the scalability of temporal component. We conduct experiments by selecting 2 ∼ 14 TeDev blocks layers, under settings with batch size as 16, training epochs as 300, and learning rate as 0.01 (Adam optimizer). The visualization results are presented quantitatively in the Fig 5. With the increase of the number of TeDev Blocks, we can observe a gradual decrease in the MSE index, and we can easily find that loacl details are becoming clearer as time module gradually increases. This phenomenon once again demonstrates the size scalability of our model.” May I ask whether the “TeDev blocks layers” in the paper refer to the “number of Fourier-module layers” you mentioned in your reply? In other words, the scalability results in Figure 5 of the paper correspond to varying the number of Fourier-module layers from 2 to 14, and the experiment does not vary the number of MFSC layers. Is that understanding correct?

easylearningscores commented 2 months ago

Hello, yes, that is correct. If you have any questions, feel free to contact us at any time.
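
For readers who want to reproduce the sweep described in the quoted scalability paragraph, a rough, hypothetical illustration using the configurable subclass sketched earlier; all constructor values below are placeholders, and stepping the depth by 2 is an assumption, not a setting from the paper:

# Hypothetical sweep over the Fourier-block depth (2-14), reporting parameter
# counts; training itself would use the quoted settings (batch size 16,
# 300 epochs, lr 0.01, Adam). All constructor values below are placeholders.
for depth in range(2, 15, 2):
    model = TeDevConfigurable(channel_in=128, channel_hid=256, N_T=4,
                              h=32, w=32, num_fourier_blocks=depth)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{depth:2d} Fourier blocks -> {n_params / 1e6:.1f}M parameters")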