Thank you for your interest. You can try the following implementation. Use `self.no_sparse` to control whether the Sparse technique is applied (`True` gives the plain Transformer, `False` the Transformer with the Sparse technique).
```python
import torch
import torch.nn as nn

# PositionalEmbedding is not defined in this snippet. It is assumed to be the usual
# sinusoidal positional encoding from Informer/Autoformer-style codebases, whose
# forward(x) returns a (1, x.size(1), d_model) tensor that broadcasts over the batch.


class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()
        # get parameters
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in
        self.period_len = configs.period_len
        self.model_type = configs.model_type
        # self.no_sparse = True  # plain Transformer, without the Sparse technique
        self.no_sparse = False   # Transformer with the Sparse technique
        if self.no_sparse:
            self.input = nn.Linear(1, 64)
            transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=2, batch_first=True,
                                                                   dim_feedforward=128)
            self.transformer = nn.TransformerEncoder(encoder_layer=transformer_encoder_layer, num_layers=3)
            self.pe = PositionalEmbedding(64)
            self.output = nn.Linear(64, self.pred_len)
        else:
            self.seg_num_x = self.seq_len // self.period_len  # n: number of periods in the input
            self.seg_num_y = self.pred_len // self.period_len  # m: number of periods to predict
            # kernel_size = 1 + 2 * (period_len // 2) is always odd, so with this padding
            # the convolution keeps the sequence length unchanged (also for odd periods)
            self.conv1d = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=1 + 2 * (self.period_len // 2),
                                    stride=1, padding=self.period_len // 2, padding_mode="zeros", bias=False)
            self.input = nn.Linear(1, 64)
            transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=2, batch_first=True,
                                                                   dim_feedforward=128)
            self.transformer = nn.TransformerEncoder(encoder_layer=transformer_encoder_layer, num_layers=3)
            self.pe = PositionalEmbedding(64)
            self.output = nn.Linear(64, self.seg_num_y)

    def forward(self, x):
        # x: b,s,c
        batch_size = x.shape[0]
        if self.no_sparse:
            seq_mean = torch.mean(x, dim=1).unsqueeze(1)
            x = (x - seq_mean)
            # b,s,c -> b,c,s -> bc,s,1
            x = x.permute(0, 2, 1).reshape(-1, self.seq_len, 1)
            x = self.input(x) + self.pe(x)  # bc,s,d
            z = torch.mean(self.transformer(x), dim=1)  # bc,d
            y = self.output(z).view(-1, self.enc_in, self.pred_len).permute(0, 2, 1)
            y = y + seq_mean
        else:
            # normalization and permute b,s,c -> b,c,s
            seq_mean = torch.mean(x, dim=1).unsqueeze(1)
            x = (x - seq_mean).permute(0, 2, 1)
            x = self.conv1d(x.reshape(-1, 1, self.seq_len)).reshape(-1, self.enc_in, self.seq_len) + x
            # b,c,s -> bc,n,w -> bc,w,n -> bcw,n,1
            x = x.reshape(-1, self.seg_num_x, self.period_len).permute(0, 2, 1).reshape(-1, self.seg_num_x, 1)
            x = self.input(x) + self.pe(x)  # bcw,n,d
            z = torch.mean(self.transformer(x), dim=1)  # bcw,d
            y = self.output(z).view(-1, self.period_len, self.seg_num_y)  # bc,w,m
            # bc,w,m -> bc,m,w -> b,c,pred_len
            y = y.permute(0, 2, 1).reshape(batch_size, self.enc_in, self.pred_len)
            y = y.permute(0, 2, 1) + seq_mean
        return y
```
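To see what the sparse branch's reshape chain (b,c,s -> bc,n,w -> bc,w,n -> bcw,n,1) actually produces, it can be replayed on a toy series. This is a minimal sketch in NumPy rather than PyTorch; NumPy's `reshape` and `transpose` follow the same row-major element order as `torch.reshape` and `permute`, so the index pattern is the same. The helper names `sparse_split`/`sparse_merge` and the sizes (s = 12, w = 4) are hypothetical, for illustration only. Each length-n row collects the points that share one phase of the period, i.e. the series down-sampled by w, and the inverse reshape used on the output side restores the original order.

```python
import numpy as np


def sparse_split(x, period_len):
    """(b, c, s) -> (b*c*w, n, 1), mirroring the sparse branch of Model.forward."""
    n = x.shape[-1] // period_len
    z = x.reshape(-1, n, period_len)  # bc, n, w: one row per period
    z = z.transpose(0, 2, 1)          # bc, w, n: one row per phase
    return z.reshape(-1, n, 1)        # bcw, n, 1


def sparse_merge(y, batch_size, enc_in, period_len):
    """(b*c*w, m) -> (b, c, m*w), mirroring the output-side reshape of Model.forward."""
    m = y.shape[-1]
    y = y.reshape(-1, period_len, m)  # bc, w, m
    y = y.transpose(0, 2, 1)          # bc, m, w
    return y.reshape(batch_size, enc_in, m * period_len)


x = np.arange(12).reshape(1, 1, 12)   # b=1, c=1, s=12
sub = sparse_split(x, period_len=4)   # w=4 sub-sequences of length n=3
print(sub[..., 0].tolist())           # [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]
print(np.array_equal(sparse_merge(sub[..., 0], 1, 1, 4), x))  # True
```

In `Model.forward`, each of these phase rows is fed to the Transformer as one short sequence, and the (bcw, m) output goes back through the merge step.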
Thank you for your reply and for providing the example code. I re-ran the experiments with the code you provided, and the results are consistent with those shown in Table 5 of the paper. I am particularly grateful for your generous sharing. Comparing the code you provided with my own version, I found that the main modifications are an additional dimension transformation when down-sampling the input data, that is, "bc,w,n -> bcw,n,1", as well as the setting of the linear layer "nn.Linear(1, 64)" and the parameter settings of the TransformerEncoderLayer. What is the purpose of these? Suppose I now have a Transformer-based forecasting model that I have modified myself: would the sparse method transfer well to it?
> The main modifications are an additional dimension transformation operation when down-sampling the input data, that is, "bc,w,n -> bcw,n,1"

Using the Sparse technique or not results in the following key differences:

1. Without the Sparse technique, the input to the Transformer Encoder has shape (bc, s, 1). Note that a channel-independent strategy with parameter sharing is used here, so the channel dimension c is merged into the batch dimension; s is the sequence length, and 1 is the number of input channels.
2. With the Sparse technique, the input shape becomes (bcw, n, 1), where n = s / w is the number of periods in the input window. A reshape splits the s dimension into (n, w), a transpose turns it into (w, n), and the w dimension is then merged into the batch dimension. The Transformer therefore processes sequences of length n instead of s. This change is the core of SparseTSF, i.e., Cross-Period Sparse Forecasting.

Logically, the core difference between the two is that the former models the original single-channel sequence, while the latter models the sub-sequences obtained by down-sampling with period w.

> the setting of the linear layer "nn.Linear(1, 64)" and the parameter settings of the TransformerEncoderLayer. What is the purpose of this?

These parameter settings are identical whether or not the Sparse technique is used. As mentioned above, the only hallmark of the Sparse technique is the different second dimension of the Transformer input; all other configurations are the same. So these settings serve no additional purpose: they are kept fixed only to compare the model with and without the Sparse technique.

> Suppose I now have a prediction model based on the transformer that I have modified myself, can I also transfer the sparse method well?

Have you designed a variant of the Transformer? Whether the Sparse technique is effective in that case may require further experiments on your side. In my personal opinion, the Sparse technique should bring some improvement on datasets with clear periodicity and for over-designed models. In other cases it may degrade performance, because it significantly reduces the model's parameter count and capacity.
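A back-of-the-envelope way to see what the reshape changes: the sketch below counts, for the two branches of the code above, how many sequences the encoder sees, their length, the number of attention scores per head and layer, and the parameters that differ between the branches. The helper `branch_stats` is hypothetical. It uses the ETTh1 setting from this thread (7 channels, seq_len = 720, pred_len = 96) and assumes a batch size of 32 and a period of 24 (the daily cycle of hourly data); the Transformer encoder itself is identical in both branches, so it is not counted.

```python
def branch_stats(batch, channels, seq_len, pred_len, period_len=None, d_model=64):
    """Hypothetical helper: compare the two branches of Model for one input batch.

    period_len=None -> no Sparse technique: b*c sequences of length s.
    period_len=w    -> Sparse technique: b*c*w sequences of length n = s // w.
    """
    if period_len is None:
        sequences, tokens, out_dim, conv_params = batch * channels, seq_len, pred_len, 0
    else:
        w = period_len
        sequences, tokens, out_dim = batch * channels * w, seq_len // w, pred_len // w
        conv_params = 1 + 2 * (w // 2)  # bias-free Conv1d with one input/output channel
    return {
        "sequences": sequences,
        "tokens_per_sequence": tokens,
        # one score per (query, key) pair, counted per head and per layer
        "attention_scores": sequences * tokens ** 2,
        # nn.Linear(d_model, out_dim): weight matrix plus bias
        "output_head_params": d_model * out_dim + out_dim,
        "conv_params": conv_params,
    }


# ETTh1 has 7 channels; batch size 32 and period 24 are illustrative assumptions
dense = branch_stats(batch=32, channels=7, seq_len=720, pred_len=96)
sparse = branch_stats(batch=32, channels=7, seq_len=720, pred_len=96, period_len=24)
print(dense)
print(sparse)
print(dense["attention_scores"] / sparse["attention_scores"])  # 24.0
```

When w divides s, the ratio of attention scores is bc·s² / (bcw·(s/w)²) = w, so the Sparse reshape cuts the attention work per layer by the period length. In this particular code, the parameter difference comes only from the smaller output layer (64·4 + 4 instead of 64·96 + 96 here) and the small Conv1d.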
I understand now. Thank you for your patient and detailed response, which has helped me a lot. I will keep following your work, and thank you again for sharing!
> Have you designed a variant of the Transformer?
Yes, I am currently researching Transformer variants for time series forecasting. Your work has been of great help to me. I will try integrating your technique into my current approach to see how it performs. I appreciate your generous sharing and explanations, and I wish you success in your research.
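The cross-period reshape does not depend on what runs on the sub-sequences, which is what makes it a candidate for other Transformer variants. Below is a minimal NumPy sketch of that wrapping pattern, assuming a hypothetical `backbone` callable that maps (N, n) sub-sequences to (N, m) forecasts; in the code posted above, that role is played by `nn.Linear(1, 64)` + positional embedding + `TransformerEncoder` + mean pooling + `nn.Linear(64, m)`. The Conv1d aggregation step is deliberately left out. With a toy backbone that repeats the last value of each sub-sequence, the wrapper reduces to a seasonal-naive forecast, which makes the data flow easy to check.

```python
import numpy as np


def cross_period_forecast(x, period_len, pred_len, backbone):
    """Hypothetical helper: apply any per-sequence forecaster across periods.

    x:        (batch, seq_len, channels), the layout used by Model.forward.
    backbone: callable mapping (N, n) sub-sequences to (N, m) forecasts,
              with n = seq_len // period_len and m = pred_len // period_len.
    Assumes period_len divides both seq_len and pred_len; no Conv1d step.
    """
    b, s, c = x.shape
    w = period_len
    n, m = s // w, pred_len // w
    mean = x.mean(axis=1, keepdims=True)                         # per-channel mean, (b, 1, c)
    z = (x - mean).transpose(0, 2, 1)                            # b, c, s
    z = z.reshape(-1, n, w).transpose(0, 2, 1).reshape(-1, n)   # bcw, n
    y = backbone(z)                                              # bcw, m
    if y.shape != (b * c * w, m):
        raise ValueError(f"backbone returned shape {y.shape}, expected {(b * c * w, m)}")
    y = y.reshape(-1, w, m).transpose(0, 2, 1).reshape(b, c, m * w)  # b, c, pred_len
    return y.transpose(0, 2, 1) + mean                           # b, pred_len, c


def repeat_last(z):
    # Toy backbone: repeat the last value of each sub-sequence twice (m = 2).
    return np.repeat(z[:, -1:], 2, axis=1)


x = np.arange(12.0).reshape(1, 12, 1)  # one channel, three periods of length 4
y = cross_period_forecast(x, period_len=4, pred_len=8, backbone=repeat_last)
print(y[0, :, 0].tolist())  # [8.0, 9.0, 10.0, 11.0, 8.0, 9.0, 10.0, 11.0]
```

Swapping `repeat_last` for a learned model (for example, a custom Transformer variant over sequences of length n) is the kind of experiment discussed above; whether it helps is an empirical question for each dataset.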
I replaced the Linear layer in the prediction part with a standard Transformer, but the results on the ETTh1 dataset were not satisfactory: they are far from the Transformer + Sparse results reported in Table 5 of the paper. May I ask how you ran this experiment? My experimental results are as follows:

```text
ETTh1_720_96_SparseTSF_ETTh1_ftM_sl720_pl96_test_0_seed2023
mse:0.6976203322410583, mae:0.560204029083252, rse:0.7933546900749207
ETTh1_720_192_SparseTSF_ETTh1_ftM_sl720_pl192_test_0_seed2023
mse:0.7232381701469421, mae:0.5843048095703125, rse:0.807603120803833
ETTh1_720_336_SparseTSF_ETTh1_ftM_sl720_pl336_test_0_seed2023
mse:0.7192806601524353, mae:0.5902762413024902, rse:0.8074232339859009
ETTh1_720_720_SparseTSF_ETTh1_ftM_sl720_pl720_test_0_seed2023
mse:0.7250006794929504, mae:0.6054915189743042, rse:0.8151193261146545
```