unit8co / darts

A Python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0
8.1k stars 882 forks

TSMixer ConditionalMixer Skip Connections #2388

Open tRosenflanz opened 6 months ago

tRosenflanz commented 6 months ago

Is your feature request related to a current problem? Please describe.

TSMixer models with num_blocks higher than 4 aren't training well. The cause is somewhat nebulous to pinpoint, but a higher number of blocks can lead to much worse results. On my dataset, anything with num_blocks of 8 remains stagnant at extremely suboptimal metrics. Even on simpler datasets like ETTh it leads to worse results, although there it is easier to attribute to overfitting.

The original paper makes no clear mention of this, but judging by the benchmarks, the authors do not train deeper "extended" models (the type implemented in Darts).

Describe proposed solution

Through some experimentation, I found that a simple skip connection in the conditional mixer layer, combining the input with the output of the time+feature mixing layers, greatly alleviates this issue. This operation isn't in the original paper, but it seems like a simple way to extend the functionality without modifying the general architecture.

There are a few ways to implement skip connections, and since this isn't the default choice, it must remain optional. Adding a new mixing_skip_connection_cls argument that is instantiated by the _ConditionalMixerLayer if specified seems to work quite cleanly. Darts could even provide the simplest variation, x_inp + x_processed, as one of the str variants. I have tried the recursive variant from https://aclanthology.org/2020.coling-main.320.pdf and it is quite effective on my dataset.
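As a sketch of how such an argument could be wired up (the resolve_skip helper and the "simple" alias below are hypothetical, not existing Darts API), the model could resolve either a registered string variant or a user-supplied class:

```python
import torch
import torch.nn as nn


class SimpleSkip(nn.Module):
    """Simplest variant: add the block input to the block output,
    projecting channels when the widths differ."""

    def __init__(self, sequence_length: int, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.projection = (
            nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()
        )

    def forward(self, x_original: torch.Tensor, x_processed: torch.Tensor) -> torch.Tensor:
        return x_processed + self.projection(x_original)


def resolve_skip(spec, sequence_length: int, input_dim: int, output_dim: int):
    """Hypothetical resolver for a mixing_skip_connection_cls argument:
    accepts None, a registered str alias, or a Skip-like class."""
    aliases = {"simple": SimpleSkip}
    if spec is None:
        return None
    cls = aliases[spec] if isinstance(spec, str) else spec
    return cls(sequence_length=sequence_length, input_dim=input_dim, output_dim=output_dim)
```

Keeping None as the default preserves today's behavior, while str aliases give users the common cases without writing a class.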

Screenshot 2024-05-19 at 12 50 20 AM

Describe potential alternatives

Additional context

Results on ETTh (this is a very exaggerated example): num_blocks=2 vs num_blocks=16 vs num_blocks=2 + simple skip vs num_blocks=16 + simple skip

Screenshot 2024-05-19 at 12 49 06 AM

Clearly, the deep model without skip connections underperforms compared to the rest.

I can post training curves from my dataset with/without skip connections showing the improvements, but I can't share the data. The dataset is fairly similar to M5, though.

Skip implementation:

import torch.nn as nn


class Skip(nn.Module):
    def __init__(self, sequence_length, input_dim, output_dim, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # project the block input to the output width so the two tensors can be added
        self.projection = (
            nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()
        )


class SimpleSkip(Skip):
    def forward(self, x_original, x_processed):
        x_original = self.projection(x_original)
        return x_processed + x_original
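A quick shape check of the classes above (restated here so the snippet runs standalone): when input and output widths differ, the projection aligns them; when they match, the skip reduces to an exact element-wise add.

```python
import torch
import torch.nn as nn


class Skip(nn.Module):
    def __init__(self, sequence_length, input_dim, output_dim, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.projection = (
            nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()
        )


class SimpleSkip(Skip):
    def forward(self, x_original, x_processed):
        return x_processed + self.projection(x_original)


x_in = torch.randn(2, 16, 8)   # (batch, sequence_length, input_dim)
x_out = torch.randn(2, 16, 4)  # block output with output_dim=4
skip = SimpleSkip(sequence_length=16, input_dim=8, output_dim=4)
assert skip(x_in, x_out).shape == (2, 16, 4)  # projection aligns channels

same = SimpleSkip(sequence_length=16, input_dim=4, output_dim=4)
assert torch.allclose(same(x_out, x_out), 2 * x_out)  # identity path: exact add
```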
flight505 commented 6 months ago

Interesting, could you share an example code or notebook?

tRosenflanz commented 4 months ago

I wanted to make it a proper branch for testing but haven't had a chance. If you want to run your own experiments, you can add SimpleSkip to the forward of _ConditionalMixerLayer:

class _ConditionalMixerLayer(nn.Module):
    def __init__(
        self,
        sequence_length: int,
        input_dim: int,
        output_dim: int,
        static_cov_dim: int,
        ff_size: int,
        activation: Callable,
        dropout: float,
        normalize_before: bool,
        norm_type: nn.Module,
    ) -> None:
        """Conditional mix layer combining time and feature mixing with static context based on the
        `PyTorch implementation of TSMixer <https://github.com/ditschuk/pytorch-tsmixer>`_.

        This module combines time mixing and conditional feature mixing, where the latter
        is influenced by static features. This allows the module to learn representations
        that are influenced by both dynamic and static features.

        Parameters
        ----------
        sequence_length
            The length of the input sequences.
        input_dim
            The number of input channels of the dynamic features.
        output_dim
            The number of output channels after feature mixing.
        static_cov_dim
            The number of channels in the static feature input.
        ff_size
            The inner dimension of the feedforward network used in feature mixing.
        activation
            The activation function used in both mixing operations.
        dropout
            The dropout probability used in both mixing operations.
        normalize_before
            Whether to apply normalization before or after mixing.
        norm_type
            The type of normalization to use.
        """
        super().__init__()

        mixing_input = input_dim
        if static_cov_dim != 0:
            self.feature_mixing_static = _FeatureMixing(
                sequence_length=sequence_length,
                input_dim=static_cov_dim,
                output_dim=output_dim,
                ff_size=ff_size,
                activation=activation,
                dropout=dropout,
                normalize_before=normalize_before,
                norm_type=norm_type,
            )
            mixing_input += output_dim
        else:
            self.feature_mixing_static = None

        self.time_mixing = _TimeMixing(
            sequence_length=sequence_length,
            input_dim=mixing_input,
            activation=activation,
            dropout=dropout,
            normalize_before=normalize_before,
            norm_type=norm_type,
        )
        self.feature_mixing = _FeatureMixing(
            sequence_length=sequence_length,
            input_dim=mixing_input,
            output_dim=output_dim,
            ff_size=ff_size,
            activation=activation,
            dropout=dropout,
            normalize_before=normalize_before,
            norm_type=norm_type,
        )
        self.skip_connection = SimpleSkip(
            sequence_length=sequence_length,
            input_dim=mixing_input,
            output_dim=output_dim,
        )

    def forward(
        self, x_inp: torch.Tensor, x_static: Optional[torch.Tensor]
    ) -> torch.Tensor:
        if self.feature_mixing_static is not None:
            x_static_mixed = self.feature_mixing_static(x_static)
            x = torch.cat([x_inp, x_static_mixed], dim=-1)
        else:
            x = x_inp
        x = self.time_mixing(x)
        x = self.feature_mixing(x)
        if self.skip_connection is not None:
            x = self.skip_connection(x_inp, x)
        return x
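The stagnation reported for deep stacks is consistent with vanishing gradients, which is exactly what additive skips counteract. The toy comparison below (plain PyTorch, not Darts code) contrasts the first-layer gradient norm of a 16-block tanh MLP with and without additive skips:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
depth, width = 16, 32


def first_layer_grad_norm(use_skip: bool) -> float:
    """Backprop through `depth` tanh blocks; return the grad norm at the first block."""
    blocks = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
    x = torch.randn(8, width)
    h = x
    for block in blocks:
        out = torch.tanh(block(h))
        h = h + out if use_skip else out  # additive skip vs plain stacking
    h.sum().backward()
    return blocks[0].weight.grad.norm().item()


g_plain = first_layer_grad_norm(use_skip=False)
g_skip = first_layer_grad_norm(use_skip=True)
print(f"plain: {g_plain:.2e}  skip: {g_skip:.2e}")
```

With plain stacking, each block multiplies the backward signal by a Jacobian with norm below 1 (tanh saturation plus small init), so the first block barely trains; the skip's identity path keeps a usable gradient, matching the observation that deep no-skip models stay stagnant.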