KhaledAlkilane89 opened this issue 8 months ago
Q1, Q3: Mamba is a module that has the same interface as a Transformer multi-head attention block. You should use the same approach and format that you would for a Transformer baseline.
Q2: Please read Section 3 of the paper carefully. It is the whole block.
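For illustration, a minimal sketch of that drop-in equivalence: both a multi-head attention sub-layer and a Mamba block map (B, L, D) to (B, L, D), so a Mamba layer can sit wherever the attention layer would. The d_model=64, the random input, and the attention hyperparameters below are arbitrary choices for the example.

    import torch
    import torch.nn as nn
    from mamba_ssm import Mamba

    batch, length, dim = 32, 96, 64
    x = torch.randn(batch, length, dim).cuda()

    # Transformer-style token mixer: (B, L, D) -> (B, L, D)
    attn = nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True).cuda()
    y_attn, _ = attn(x, x, x)

    # Mamba used in the same position, with the same input/output shapes
    mamba = Mamba(d_model=dim, d_state=16, d_conv=4, expand=2).cuda()
    y_mamba = mamba(x)

    assert y_attn.shape == y_mamba.shape == x.shape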
Thank you for your prompt response! I'm curious whether Mamba has been evaluated on any time series forecasting benchmarks; if so, I'd greatly appreciate any links or tutorials you could share.

Additionally, I'd be grateful if you could review the code snippet below. I'm implementing a PatchTST-like structure for time series forecasting, using Mamba in place of the Transformer encoder. However, the loss exhibits unusual behavior: it decreases for the first two epochs and then starts increasing, whereas the Transformer baseline decreases consistently and reaches approximately 0.384. I'm unsure what's causing this divergence. I'm using the following configuration:

Dataset: ETTh1
n_vars: 7
batch: 32
seq_len: 96
pred_len: 96
patch_size: 12
stride: 12
lr: 3e-3
Optimizer: AdamW
Loss function: L1Loss
import torch
import torch.nn as nn

from mamba_ssm import Mamba
from utils import RevIN, positional_encoding


class Model(nn.Module):
    def __init__(self, input, seq_len, pred_len, d_model, d_state=16, d_conv=4,
                 expand=2, patch_size=12, stride=12, dropout=0.2):
        super(Model, self).__init__()
        self.n_vars = input
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.patch_size = patch_size
        self.stride = stride
        self.num_patches = self.compute_num_patches()

        self.mamba = Mamba(d_model, d_state, d_conv, expand)
        self.revin = RevIN(input)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()
        self.flat = nn.Flatten(start_dim=-2)
        self.W_P = nn.Linear(patch_size, d_model)
        self.W_pos = positional_encoding('sincos', True, self.num_patches, d_model)
        self.head = nn.Linear(d_model * self.num_patches, pred_len)

    def compute_num_patches(self):
        return ((self.seq_len - self.patch_size) // self.stride) + 1

    def forward(self, x):
        batch_size = x.size(0)                 # x: [batch, seq_len, n_vars]

        # RevIN "norm"
        z = x.transpose(-1, -2)                # [batch, n_vars, seq_len]
        z, reverse_fn = self.revin(z)

        # Patching: [batch, n_vars, seq_len] --> [batch, n_vars, num_patches, patch_size]
        z = z.unfold(dimension=-1, size=self.patch_size, step=self.stride)

        # Patch embedding: [batch, n_vars, num_patches, patch_size] --> [batch, n_vars, num_patches, d_model]
        z = self.W_P(z)

        # Fold the variables into the batch and add positional encoding:
        # [batch, n_vars, num_patches, d_model] --> [batch * n_vars, num_patches, d_model]
        z = z.reshape(batch_size * self.n_vars, self.num_patches, self.d_model)
        z = self.dropout(z + self.W_pos)

        # Mamba layer: [batch * n_vars, num_patches, d_model] --> [batch * n_vars, num_patches, d_model]
        m_out = self.mamba(z)

        # Unfold the variables again, keeping the (num_patches, d_model) order produced by Mamba;
        # reshaping straight to [batch, n_vars, d_model, num_patches] would scramble the features.
        m_out = m_out.reshape(batch_size, self.n_vars, self.num_patches, self.d_model)

        # Prediction head
        output = self.flat(m_out)              # [batch, n_vars, num_patches * d_model]
        output = self.head(output)             # [batch, n_vars, pred_len]

        # RevIN "denorm"
        output = reverse_fn(output)
        output = output.transpose(-1, -2)      # [batch, pred_len, n_vars]
        return output
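For reference, a minimal shape check plus a single optimization step with the configuration stated above might look like the following. The d_model value, the dummy batch, and the random target are assumptions for the sketch; real training would iterate over the ETTh1 loader.

    import torch
    import torch.nn as nn

    # Hyperparameters from the post; d_model is not stated above and is assumed here.
    model = Model(input=7, seq_len=96, pred_len=96, d_model=128,
                  patch_size=12, stride=12).cuda()

    x = torch.randn(32, 96, 7).cuda()     # [batch, seq_len, n_vars], as for ETTh1 with n_vars=7
    y = model(x)
    print(y.shape)                        # expected: torch.Size([32, 96, 7]) = [batch, pred_len, n_vars]

    # One optimization step with the stated optimizer and loss.
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
    criterion = nn.L1Loss()
    target = torch.randn_like(y)          # placeholder target; use the real future window in practice
    loss = criterion(y, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()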
I'm new to Mamba and have some questions about its functionality.
From my understanding, both during training and evaluation the Mamba network (a stack of class Block in mamba_simple.py) receives inputs with the shape (B, L, D).
Since Mamba is an RNN-like model, I'm wondering whether there is a step-by-step process similar to $(x_{t-1},\ h_{t-1}) \rightarrow \text{model} \rightarrow (x_t,\ h_t)$. If so, where does it happen, and when and how should inference_params be used? (utils/generation.py is a little hard to follow.) Is inference_params used during training?
Thanks, and please point me to any simpler tutorials if they exist.
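In case it helps, a rough sketch of the two modes as I understand them from mamba_simple.py and utils/generation.py (the layer_idx argument and the seqlen_offset bookkeeping below are the parts worth double-checking): during training and full-sequence evaluation, the whole (B, L, D) tensor is processed in parallel and inference_params stays None; inference_params only comes into play for step-by-step generation, where it caches the conv and SSM states between calls.

    import torch
    from mamba_ssm import Mamba
    from mamba_ssm.utils.generation import InferenceParams

    layer = Mamba(d_model=64, d_state=16, d_conv=4, expand=2, layer_idx=0).cuda()
    x = torch.randn(2, 32, 64).cuda()

    # Training / evaluation over a full sequence: parallel mode, no inference_params.
    y_parallel = layer(x)                          # (B, L, D) -> (B, L, D)

    # Step-by-step decoding: the cache carries the hidden state (conv state + SSM state) forward.
    params = InferenceParams(max_seqlen=64, max_batch_size=2)
    steps = []
    for t in range(x.shape[1]):
        steps.append(layer(x[:, t:t + 1], inference_params=params))
        params.seqlen_offset += 1                  # tells the layer the next call is a single new step
    y_stepwise = torch.cat(steps, dim=1)           # should roughly match y_parallel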
Hi there, thanks for this amazing work! I'm trying to use Mamba for multivariate time series forecasting, but I'm running into some issues and could really use your help. Here's what I'm working with:

1) Input shape: I'm feeding Mamba data shaped [batch, length, dim], where dim is the number of variables in my dataset. Is this the correct approach, or should I project this to a higher dimension using the two projection layers within the Mamba block?

2) Mamba block clarification: In the statement from mamba_ssm import Mamba, does Mamba represent the entire block depicted in Figure 3 of the paper, or just the SSM block within that figure? If it's the whole block, does expand refer to the hidden state size or to the number of layers?

3) Patching with Mamba: Patching long time series into smaller chunks works well with Transformers, so I'm trying the same approach with Mamba. However, I'm unsure what the appropriate input format for the Mamba layer should be after patching, since the patched data has a 4D shape of [batch, n_vars, num_patches, patch_size] (see the reshaping sketch below).

I've included Figure 3 from the paper for reference.
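For what it's worth, a minimal sketch of the reshaping asked about in (3), mirroring the channel-independent layout PatchTST uses: fold the variable dimension into the batch so the Mamba layer still receives a 3D (B, L, D) tensor. The Linear patch embedding and the chosen d_model here are illustrative assumptions, not something from the Mamba repo.

    import torch
    import torch.nn as nn
    from mamba_ssm import Mamba

    batch, n_vars, seq_len = 32, 7, 96
    patch_size, stride, d_model = 12, 12, 128
    num_patches = (seq_len - patch_size) // stride + 1       # 8 for these values

    x = torch.randn(batch, n_vars, seq_len).cuda()

    # [batch, n_vars, seq_len] -> [batch, n_vars, num_patches, patch_size]
    patches = x.unfold(dimension=-1, size=patch_size, step=stride)

    # Embed each patch, then merge variables into the batch:
    # [batch, n_vars, num_patches, patch_size] -> [batch * n_vars, num_patches, d_model]
    w_p = nn.Linear(patch_size, d_model).cuda()
    tokens = w_p(patches).reshape(batch * n_vars, num_patches, d_model)

    mamba = Mamba(d_model=d_model, d_state=16, d_conv=4, expand=2).cuda()
    out = mamba(tokens)                                      # [batch * n_vars, num_patches, d_model]

    # Recover the variable dimension for the prediction head.
    out = out.reshape(batch, n_vars, num_patches, d_model)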