state-spaces / mamba

Mamba SSM architecture

Question about Mamba for Multivariate Time Series Forecasting #190

Open KhaledAlkilane89 opened 8 months ago

KhaledAlkilane89 commented 8 months ago

Hi there, thanks for this amazing work! I'm trying to use Mamba for multivariate time series forecasting, but I'm encountering some issues and could really use your help. Here's what I'm working with:

1) Input shape: I'm feeding Mamba data shaped [batch, length, dim], where dim is the number of variables in my dataset. Is this the correct approach, or should I project to a higher dimension using the two projection layers within the Mamba block?
2) Mamba block clarification: In the from mamba_ssm import Mamba statement, does "Mamba" represent the entire block depicted in Figure 3 of the paper, or just the SSM block within that figure? If it's the whole block, does "expand" refer to the hidden state size or the number of layers?
3) Patching with Mamba: Patching long time series sequences into smaller chunks works well with Transformers, so I'm trying the same approach with Mamba. However, I'm unsure what the appropriate input format for the Mamba layer should be after patching; the patched data will have the 4D shape [batch, n_vars, num_patches, patch_size].

I've included the Figure 3 image from the paper for reference:

[Figure 3 from the Mamba paper: the Mamba block architecture]
albertfgu commented 8 months ago

Q1, Q3: Mamba is a module that has the same interface as a Transformer multi-head attention block. You should use the same approach and format that you would for a Transformer baseline.

Q2: Please read Section 3 of the paper carefully. It is the whole block.
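
(For illustration, a minimal sketch of that interface, with made-up sizes: Mamba maps [batch, length, dim] to [batch, length, dim], the same contract as a multi-head attention block, so it drops into an existing Transformer-style stack without any reshaping.)

import torch
from mamba_ssm import Mamba

# Hypothetical sizes for illustration: batch=8, length=96, dim=64.
# dim plays the role of d_model in a Transformer baseline.
x = torch.randn(8, 96, 64, device="cuda")

# Mamba maps [batch, length, dim] -> [batch, length, dim], the same contract
# as multi-head attention, so it can replace the attention sub-layer directly.
mixer = Mamba(d_model=64, d_state=16, d_conv=4, expand=2).to("cuda")
y = mixer(x)
print(y.shape)  # torch.Size([8, 96, 64])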

KhaledAlkilane89 commented 8 months ago

Thank you for your prompt response! I'm curious whether Mamba has been evaluated on any time series forecasting benchmarks; if so, I'd greatly appreciate any links or tutorials you could share. I'd also be grateful if you could review the code snippet below. I'm implementing a PatchTST-like structure for time series forecasting, using Mamba instead of the Transformer encoder. However, the loss exhibits an unusual behavior: it decreases for the first two epochs and then starts increasing, unlike the Transformer, whose loss decreases consistently and reaches approximately 0.384. I'm unsure what's causing this divergence. I'm using the following configuration:

Dataset: ETTh1
n_vars: 7
batch: 32
seq_len: 96
pred_len: 96
patch_size: 12
stride: 12
lr: 3e-3
Optimizer: AdamW
Loss function: L1Loss

import torch
import torch.nn as nn

from mamba_ssm import Mamba
from utils import RevIN, positional_encoding

class Model(nn.Module):
    def __init__(self, input, seq_len, pred_len, d_model, d_state=16, d_conv=4, expand=2, patch_size=12, stride=12, dropout=0.2):
        super(Model, self).__init__()
        self.n_vars = input
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.patch_size = patch_size
        self.stride = stride

        self.num_patches = self.compute_num_patches()
        self.mamba = Mamba(d_model, d_state, d_conv, expand)
        self.revin = RevIN(input)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()
        self.flat = nn.Flatten(start_dim=-2)
        self.W_P = nn.Linear(patch_size, d_model)
        self.W_pos = positional_encoding('sincos', True, self.num_patches, d_model)
        self.head = nn.Linear(d_model * self.num_patches, pred_len)

    def compute_num_patches(self):
        return ((self.seq_len - self.patch_size) // self.stride) + 1

    def forward(self, x):  
        batch_size = x.size(0)                                                                          # x: [batch, seq_len, n_vars]

        # RevIN "norm"
        z = x.transpose(-1, -2)                                                                         # z: [batch, n_vars, seq_len]
        z, reverse_fn = self.revin(z)

        # Creating patches                                                                              # [batch, n_vars, seq_len] --> [batch, n_vars, num_patches, patch_size]
        z = z.unfold(dimension=-1, size=self.patch_size, step=self.stride)   

        # Projection from patch_size to d_model                                                         # [batch, n_vars, num_patches, patch_size] --> [batch, n_vars, num_patches, d_model]
        z = self.W_P(z)

        # Flatten batch and variable dims, then add positional encoding                                 # [batch, n_vars, num_patches, d_model] --> [batch * n_vars, num_patches, d_model]
        z = z.reshape(z.shape[0] * z.shape[1], z.shape[2], z.shape[3])
        z = self.dropout(z + self.W_pos)

        # Mamba layer                                                                                   # [batch * n_vars, num_patches, d_model] --> [batch * n_vars, num_patches, d_model]
        m_out = self.mamba(z)
        m_out = m_out.reshape(batch_size, self.n_vars, self.num_patches, self.d_model)                  # [batch, n_vars, num_patches, d_model]

        # Prediction head                       
        output = self.flat(m_out)                                                                       # [batch, n_vars, num_patches * d_model]
        output = self.head(output)                                                                      # [batch, n_vars, d_model * num_patches] --> [batch, n_vars, pred_len]

        # RevIN "denorm"
        output = reverse_fn(output)
        output = output.transpose(-1, -2)
        return output 
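
(A quick shape check of the module above, using the configuration listed in this comment; d_model=128 is an assumed value, and RevIN / positional_encoding are the author's own utils, so this only illustrates the intended shapes.)

# Hypothetical shape check: ETTh1 has 7 variables, seq_len = pred_len = 96, batch = 32.
model = Model(input=7, seq_len=96, pred_len=96, d_model=128).cuda()
x = torch.randn(32, 96, 7, device="cuda")   # [batch, seq_len, n_vars]
y = model(x)
print(y.shape)                              # expected: torch.Size([32, 96, 7])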

[Screenshot: training loss curve, decreasing for the first two epochs and then increasing]

AndssY commented 8 months ago

> Q1, Q3: Mamba is a module that has the same interface as a Transformer multi-head attention block. You should use the same approach and format that you would for a Transformer baseline.
>
> Q2: Please read Section 3 of the paper carefully. It is the whole block.

I'm new to Mamba and have some questions about its functionality. From my understanding, both during training and evaluation the Mamba network (a stack of class Block in mamba_simple.py) receives inputs with shape (B, L, D). Since Mamba is an RNN-like model, I'm wondering whether there is a step-by-step process like $(x_{t-1}, h_{t-1}) \rightarrow \text{model} \rightarrow (x_t, h_t)$, and if so, where it occurs. I'm also unsure when and how inference_params should be used (utils/generation.py is a little hard to follow; is it used during training?). Thanks if there are any simpler tutorials available.
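
(For context, a minimal sketch of the two usage modes, assuming the InferenceParams helper in mamba_ssm/utils/generation.py and the layer_idx-based state cache in mamba_simple.py behave as in the current repository; this is an illustration, not an official tutorial. During training the full sequence is processed in parallel with no inference_params; during autoregressive decoding the recurrent state is carried in inference_params and updated one token at a time.)

import torch
from mamba_ssm import Mamba
from mamba_ssm.utils.generation import InferenceParams

# layer_idx is required so the block can find its state in the inference cache.
mixer = Mamba(d_model=64, d_state=16, d_conv=4, expand=2, layer_idx=0).to("cuda")

# Training / full-sequence evaluation: the whole (B, L, D) sequence is processed
# in parallel by the selective scan; no inference_params is involved.
x = torch.randn(2, 32, 64, device="cuda")
y = mixer(x)                                      # (2, 32, 64)

# Autoregressive decoding: the recurrent state (h_{t-1} -> h_t) lives inside
# inference_params and is updated in place as tokens are fed one step at a time.
params = InferenceParams(max_seqlen=64, max_batch_size=2)
x_t = torch.randn(2, 1, 64, device="cuda")        # one token per step
for _ in range(16):
    y_t = mixer(x_t, inference_params=params)     # reuses and updates the cached conv/SSM state
    params.seqlen_offset += 1
    x_t = y_t                                     # feed the output back in (toy example)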