state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Training a text classifier on Mamba #163

Open mdabedr opened 9 months ago

mdabedr commented 9 months ago

Hello. Can you please share some insights on how one can train a text classifier with Mamba?

albertfgu commented 9 months ago

The same way you would do it with a Transformer. There might be two differences that come to mind:

mdabedr commented 9 months ago

Thank you for your reply. Is the MambaLMHeadModel defined in mixer_seq_simple.py useful for this particular application?

tridao commented 9 months ago

You'd probably want to write a MambaClassifierHeadModel that has a similar structure: a Mamba model backbone with a classifier head.
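A minimal sketch of what such a class could look like, assuming a backbone that maps input_ids of shape (batch, seq_len) to hidden states of shape (batch, seq_len, d_model). The name MambaClassifierHeadModel is hypothetical (it is not part of mamba_ssm), and the masked mean pooling is one design choice among several; taking the hidden state of the last real token is a common alternative for causal models.

```python
import torch
import torch.nn as nn


class MambaClassifierHeadModel(nn.Module):
    """Hypothetical classifier: a sequence backbone followed by a linear head.

    `backbone` is assumed to map input_ids (batch, seq_len) to hidden states
    (batch, seq_len, d_model).
    """

    def __init__(self, backbone: nn.Module, d_model: int, num_labels: int):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(d_model, num_labels)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        hidden = self.backbone(input_ids)  # (batch, seq_len, d_model)
        # attention_mask is 1 for real tokens and 0 for padding
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq_len, 1)
        # Mean over real tokens only; clamp avoids division by zero on empty rows
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.classifier(pooled)  # (batch, num_labels)
```

In practice the backbone could be, for example, the `backbone` attribute of a pretrained MambaLMHeadModel, with `d_model=model.config.d_model`. Pad batches on the right: Mamba is causal, so hidden states at real tokens are not affected by trailing padding, whereas left padding would push pad tokens through the recurrent state first.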

yudizhangzyd commented 9 months ago

Trying to do the same task here, but an error occurs:

~/anaconda3/envs/myenv/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py:136, in Mamba.forward(self, hidden_states, inference_params)
    132         return out
    134 # We do matmul and transpose BLH -> HBL at the same time
    135 xz = rearrange(
--> 136     self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
    137     "d (b l) -> b d l",
    138     l=seqlen,
    139 )
    140 if self.in_proj.bias is not None:
    141     xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x1 and 12x1920)

My input has shape (128, 15, 12), and the Mamba model is configured as:

d_model=12,  # Model dimension d_model
d_state=16,   # SSM state expansion factor
d_conv=4,     # Local convolution width
expand=2      # Block expansion factor

Any insights on this? Thanks.

albertfgu commented 9 months ago

The inputs to the Mamba block probably aren't formatted correctly. Check the example in the README and double check all your shapes.
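One way to read the error above, assuming the in_proj definition in mamba_simple.py (`nn.Linear(d_model, 2 * d_inner)` with `d_inner = expand * d_model`): mat2 is the rearranged input of shape (d, b*l) = (12, 128 * 15) = (12, 1920), and mat1 is `in_proj.weight`. A 4x1 weight is what a block built with d_model=1 and expand=2 would have, so the layer that raised the error seems to have been constructed with d_model=1 rather than 12. The helper functions below are hypothetical and only reproduce the shape arithmetic.

```python
def mamba_in_proj_weight_shape(d_model: int, expand: int = 2) -> tuple:
    """(out_features, in_features) of Mamba.in_proj.weight, assuming
    in_proj = nn.Linear(d_model, 2 * d_inner) with d_inner = expand * d_model."""
    d_inner = int(expand * d_model)
    return (2 * d_inner, d_model)


def rearranged_input_shape(batch: int, seqlen: int, dim: int) -> tuple:
    """Shape after rearrange(x, "b l d -> d (b l)"), the right-hand matmul operand."""
    return (dim, batch * seqlen)


def matmul_compatible(weight_shape: tuple, input_shape: tuple) -> bool:
    """weight @ input is valid only if weight's columns equal input's rows."""
    return weight_shape[1] == input_shape[0]


# Reproduce the shapes from the error message
bad_weight = mamba_in_proj_weight_shape(d_model=1)    # (4, 1): block built with d_model=1
good_weight = mamba_in_proj_weight_shape(d_model=12)  # (48, 12): block built with d_model=12
x_shape = rearranged_input_shape(128, 15, 12)         # (12, 1920)
print(bad_weight, x_shape, matmul_compatible(bad_weight, x_shape))
print(good_weight, x_shape, matmul_compatible(good_weight, x_shape))
```

Printing `block.in_proj.weight.shape` on the instantiated block is a quick way to confirm which d_model it was actually built with.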

yudizhangzyd commented 9 months ago

That's weird; my input data runs successfully through LSTM and Transformer models.

maksymdolgikh commented 9 months ago

One way is to use pretrained models from Hugging Face together with the transformers library:

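import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

num_labels = 2  # number of classes in your task

# Tokenizer used with the pretrained Mamba checkpoints
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m")

# Replace the vocabulary-sized LM head with a num_labels-way classification head
model.lm_head = torch.nn.Linear(model.config.d_model, num_labels)
# The mamba_ssm kernels require a CUDA device; move the model after swapping the head
model = model.to("cuda")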

From here, you fine-tune the resulting model on your classification task. This approach was used with great success on Kaggle, where you can find more details: https://www.kaggle.com/competitions/llm-detect-ai-generated-text/discussion/470093
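A minimal sketch of a single fine-tuning step for the model above, assuming right-padded batches (input_ids plus an attention_mask from the tokenizer) and integer class labels. The replaced lm_head still produces one prediction per position, so the sketch takes the logits at each sequence's last real token. The helper names last_token_logits and train_step are hypothetical, not part of mamba_ssm or transformers.

```python
import torch
import torch.nn.functional as F


def last_token_logits(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Select each sequence's logits at its last real (non-padding) token.

    logits: (batch, seq_len, num_labels); attention_mask: (batch, seq_len), 1 = token.
    Assumes right padding, so the causal model's output at the last real token
    does not depend on the padding that follows it.
    """
    last_idx = attention_mask.sum(dim=1) - 1  # (batch,)
    batch_idx = torch.arange(logits.size(0), device=logits.device)
    return logits[batch_idx, last_idx]  # (batch, num_labels)


def train_step(model, optimizer, input_ids, attention_mask, labels) -> float:
    """One optimization step; model(input_ids).logits is assumed to be
    (batch, seq_len, num_labels), as with the swapped lm_head above."""
    model.train()
    logits = model(input_ids).logits
    loss = F.cross_entropy(last_token_logits(logits, attention_mask), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

With the Hugging Face tokenizer, `tokenizer(texts, padding=True, return_tensors="pt")` returns both input_ids and attention_mask; set `tokenizer.padding_side = "right"`, and if the tokenizer defines no pad token, reusing `tokenizer.eos_token` as the pad token is a common workaround. Move each batch to the model's CUDA device before calling train_step.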