amazon-science / earth-forecasting-transformer

Official implementation of Earthformer
Apache License 2.0

Prediction accuracy using Earthformer is lower than Rainformer (Issue #67)

Open · Helomin opened this issue 8 months ago

Helomin commented 8 months ago

Both models are trained on the SEVIR dataset to predict two hours ahead, using the AdamW optimizer with a learning rate of 1e-4; neither run uses the task's learning-rate optimization strategy.

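For reference, here is a minimal sketch of that optimizer setup (an assumption about the training script, not the exact code): plain AdamW at lr=1e-4 with no learning-rate schedule attached, and MSE used only as a stand-in loss.

```python
import torch
import torch.nn.functional as F

def make_optimizer(model: torch.nn.Module) -> torch.optim.AdamW:
    # AdamW at lr=1e-4, as described above; no LR schedule / warmup is attached.
    return torch.optim.AdamW(model.parameters(), lr=1e-4)

def train_step(model, optimizer, x, y):
    # One plain training step; MSE is an assumed stand-in for whichever loss
    # the Earthformer and Rainformer runs actually use.
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()
```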
Here are the Earthformer model parameter settings:

```python
base_units=128, block_units=None, scale_alpha=1.0,
num_heads=4, attn_drop=0.0, proj_drop=0.0, ffn_drop=0.0,
# inter-attn downsample/upsample
downsample=2, downsample_type='patch_merge',
upsample_type="upsample", upsample_kernel_size=3,
# encoder
enc_depth=[2, 2], enc_attn_patterns=None,
enc_cuboid_size=[(4, 4, 4), (4, 4, 4)],
enc_cuboid_strategy=[('l', 'l', 'l'), ('d', 'd', 'd')],
enc_shift_size=[(0, 0, 0), (0, 0, 0)],
enc_use_inter_ffn=True,
# decoder
dec_depth=[2, 2], dec_cross_start=0,
dec_self_attn_patterns=None,
dec_self_cuboid_size=[(4, 4, 4), (4, 4, 4)],
dec_self_cuboid_strategy=[('l', 'l', 'l'), ('d', 'd', 'd')],
dec_self_shift_size=[(1, 1, 1), (0, 0, 0)],
dec_cross_attn_patterns=None,
dec_cross_cuboid_hw=[(4, 4), (4, 4)],
dec_cross_cuboid_strategy=[('l', 'l', 'l'), ('d', 'l', 'l')],
dec_cross_shift_hw=[(0, 0), (0, 0)],
dec_cross_n_temporal=[1, 2],
dec_cross_last_n_frames=None,
dec_use_inter_ffn=True,
dec_hierarchical_pos_embed=False,
# global vectors
num_global_vectors=8,
use_dec_self_global=False, dec_self_update_global=True,
use_dec_cross_global=False,
use_global_vector_ffn=False, use_global_self_attn=True,
separate_global_qkv=True, global_dim_ratio=1,
z_init_method='zeros',
# initial downsample and final upsample
initial_downsample_type="stack_conv",
initial_downsample_activation="leaky",
# initial_downsample_type == "conv"
initial_downsample_scale=1,
initial_downsample_conv_layers=2,
final_upsample_conv_layers=2,
# initial_downsample_type == "stack_conv"
initial_downsample_stack_conv_num_layers=3,
initial_downsample_stack_conv_dim_list=[16, 64, 128],  # [96, 384, 768]
initial_downsample_stack_conv_downscale_list=[3, 2, 2],
initial_downsample_stack_conv_num_conv_list=[2, 2, 2],
# end of initial downsample and final upsample
ffn_activation='gelu', gated_ffn=False,
norm_layer='layer_norm', padding_type='ignore',
pos_embed_type='t+hw', checkpoint_level=0,
use_relative_pos=True, self_attn_use_final_proj=True,
dec_use_first_self_attn=False,
# initialization
attn_linear_init_mode="0", ffn_linear_init_mode="0",
conv_init_mode="0", down_up_linear_init_mode="0",
norm_init_mode="0",
```
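For context, here is a minimal sketch of how settings like these would be passed to the model, assuming the `CuboidTransformerModel` class and import path used by this repo's SEVIR training script; the `input_shape`/`target_shape` values are hypothetical placeholders for the two-hour setting, and only a subset of the keyword arguments above is repeated (anything omitted falls back to the class defaults):

```python
import torch
from earthformer.cuboid_transformer.cuboid_transformer import CuboidTransformerModel

# Hypothetical (T, H, W, C) shapes; replace with the actual 2-hour SEVIR setting used.
input_shape = (13, 384, 384, 1)
target_shape = (24, 384, 384, 1)

model = CuboidTransformerModel(
    input_shape=input_shape,
    target_shape=target_shape,
    base_units=128,
    num_heads=4,
    enc_depth=[2, 2],
    dec_depth=[2, 2],
    num_global_vectors=8,
    initial_downsample_type="stack_conv",
    initial_downsample_activation="leaky",
    initial_downsample_stack_conv_num_layers=3,
    initial_downsample_stack_conv_dim_list=[16, 64, 128],
    initial_downsample_stack_conv_downscale_list=[3, 2, 2],
    initial_downsample_stack_conv_num_conv_list=[2, 2, 2],
    # ...plus the remaining keyword arguments listed above; anything omitted
    # here falls back to the class defaults.
)

with torch.no_grad():
    x = torch.randn(1, *input_shape)   # (batch, T_in, H, W, C)
    y_hat = model(x)                   # expected: (batch, T_out, H, W, C)
```

If the comparison with Rainformer runs at a different spatial or temporal resolution, these shapes (and the initial downsample scales) would need to change accordingly.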