tRosenflanz opened 6 months ago
Interesting, could you share example code or a notebook?
I wanted to make it a proper branch for testing but haven't had a chance. If you want to run your own experiments, you can add `SimpleSkip` to the `forward` of `_ConditionalMixerLayer`:
```python
from typing import Callable, Optional

import torch
import torch.nn as nn

# _FeatureMixing and _TimeMixing are the existing helper modules from Darts'
# TSMixer implementation; SimpleSkip is a custom additive skip-connection
# module (not part of Darts).


class _ConditionalMixerLayer(nn.Module):
    def __init__(
        self,
        sequence_length: int,
        input_dim: int,
        output_dim: int,
        static_cov_dim: int,
        ff_size: int,
        activation: Callable,
        dropout: float,
        normalize_before: bool,
        norm_type: nn.Module,
    ) -> None:
        """Conditional mix layer combining time and feature mixing with static context, based on the
        `PyTorch implementation of TSMixer <https://github.com/ditschuk/pytorch-tsmixer>`_.

        This module combines time mixing and conditional feature mixing, where the latter
        is influenced by static features. This allows the module to learn representations
        that are influenced by both dynamic and static features.

        Parameters
        ----------
        sequence_length
            The length of the input sequences.
        input_dim
            The number of input channels of the dynamic features.
        output_dim
            The number of output channels after feature mixing.
        static_cov_dim
            The number of channels in the static feature input.
        ff_size
            The inner dimension of the feedforward network used in feature mixing.
        activation
            The activation function used in both mixing operations.
        dropout
            The dropout probability used in both mixing operations.
        normalize_before
            Whether to apply normalization before or after mixing.
        norm_type
            The type of normalization to use.
        """
        super().__init__()

        mixing_input = input_dim
        if static_cov_dim != 0:
            self.feature_mixing_static = _FeatureMixing(
                sequence_length=sequence_length,
                input_dim=static_cov_dim,
                output_dim=output_dim,
                ff_size=ff_size,
                activation=activation,
                dropout=dropout,
                normalize_before=normalize_before,
                norm_type=norm_type,
            )
            mixing_input += output_dim
        else:
            self.feature_mixing_static = None

        self.time_mixing = _TimeMixing(
            sequence_length=sequence_length,
            input_dim=mixing_input,
            activation=activation,
            dropout=dropout,
            normalize_before=normalize_before,
            norm_type=norm_type,
        )
        self.feature_mixing = _FeatureMixing(
            sequence_length=sequence_length,
            input_dim=mixing_input,
            output_dim=output_dim,
            ff_size=ff_size,
            activation=activation,
            dropout=dropout,
            normalize_before=normalize_before,
            norm_type=norm_type,
        )
        # new: skip connection combining the block input with the mixed output
        self.skip_connection = SimpleSkip(
            sequence_length=sequence_length,
            input_dim=mixing_input,
            output_dim=output_dim,
        )

    def forward(
        self, x_inp: torch.Tensor, x_static: Optional[torch.Tensor]
    ) -> torch.Tensor:
        if self.feature_mixing_static is not None:
            x_static_mixed = self.feature_mixing_static(x_static)
            x = torch.cat([x_inp, x_static_mixed], dim=-1)
        else:
            x = x_inp
        x = self.time_mixing(x)
        x = self.feature_mixing(x)
        if self.skip_connection is not None:
            # combine the raw block input with the time+feature mixed output
            x = self.skip_connection(x_inp, x)
        return x
```
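For reference, a minimal sketch of what the `SimpleSkip` module could look like (this is just one possible shape, not necessarily the exact implementation; `nn.LazyLinear` is used so the projection width is inferred from whatever block input is passed in):

```python
import torch
import torch.nn as nn


class SimpleSkip(nn.Module):
    """Additive skip: project the block input to the output width and add it to the mixed output."""

    def __init__(self, sequence_length: int, input_dim: int, output_dim: int) -> None:
        super().__init__()
        # sequence_length/input_dim are kept only to match the constructor call above;
        # LazyLinear infers its input width on the first forward pass
        self.proj = nn.LazyLinear(output_dim)

    def forward(self, x_inp: torch.Tensor, x_processed: torch.Tensor) -> torch.Tensor:
        # residual combination of the raw block input and the time+feature mixed output
        return self.proj(x_inp) + x_processed
```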
Is your feature request related to a current problem? Please describe.
TSMixer models with `num_blocks` higher than 4 aren't training well. It is somewhat hard to pinpoint, but a higher number of blocks can lead to much worse results. On my dataset, anything with `num_blocks` of 8 remains stagnant at extremely suboptimal metrics. Even on simpler datasets like ETTh it leads to worse results, although there it is easier to attribute to overfitting.
There is no clear mention of this in the original paper, but according to the benchmarks they do not train deeper "extended" models (the type implemented in Darts).
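For context, the behaviour shows up with configurations along these lines (a minimal sketch; exact hyperparameters don't matter much, and the dataset here is just the public ETTh1 shipped with Darts):

```python
from darts.datasets import ETTh1Dataset
from darts.models import TSMixerModel

series = ETTh1Dataset().load()

# num_blocks <= 4 trains as expected; a deeper stack like this one trains much worse
model = TSMixerModel(
    input_chunk_length=96,
    output_chunk_length=24,
    num_blocks=8,
    n_epochs=20,
)
model.fit(series)
```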
Describe proposed solution
Through some experimentation, a simple skip connection in the conditional mixer layer that combines the input and the output of the time+feature mixing layers greatly alleviates this issue. This operation isn't in the original paper, but it seems like a simple way to extend the functionality without modifying the general architecture.
There are a few ways to implement skip connections, and since it isn't the default choice it must remain optional. Adding a new `mixing_skip_connection_cls` argument that `_ConditionalMixerLayer` instantiates when specified seems to work quite cleanly; a rough sketch of that wiring is shown below. Darts could even provide the simplest variation, `x_inp + x_processed`, as one of the built-in `str` variants.
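To make that concrete, here is roughly the change I have in mind, shown relative to the class posted above (the argument name `mixing_skip_connection_cls` is just a suggestion, not existing Darts API):

```python
# in _ConditionalMixerLayer.__init__, with a new optional argument
# mixing_skip_connection_cls: Optional[Type[nn.Module]] = None
self.skip_connection = None
if mixing_skip_connection_cls is not None:
    # instantiate the user-supplied skip class with the same dimensions
    # used by the mixing layers
    self.skip_connection = mixing_skip_connection_cls(
        sequence_length=sequence_length,
        input_dim=mixing_input,
        output_dim=output_dim,
    )

# at the end of _ConditionalMixerLayer.forward
if self.skip_connection is not None:
    x = self.skip_connection(x_inp, x)
```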
I have tried the recursive variant from https://aclanthology.org/2020.coling-main.320.pdf and it is quite effective on my dataset.

Describe potential alternatives
Additional context
Results from ETTh (this is a very exaggerated example): num_blocks=2 vs num_blocks=16 vs num_blocks=2 + simple skip vs num_blocks=16 + simple skip.
Clearly, the deep model without the skip connection underperforms compared to the rest.
I can post training curves from my dataset with and without skip connections showing the improvement, but I can't share the data. The dataset is fairly similar to M5, though.
skip implementation