mdabedr opened 9 months ago
The same way you would do it with a Transformer; there might be two differences that come to mind.
Thank you for your reply. Is the MambaLMHeadModel defined in mixer_seq_simple.py useful for this particular application?
You'd probably want to write a MambaClassifierHeadModel that has a similar structure: a Mamba model backbone with a classifier head.
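A minimal sketch of what that could look like, assuming MixerModel from mixer_seq_simple as the backbone and last-token pooling (note that newer mamba_ssm versions also require a d_intermediate argument when constructing MixerModel):

```python
import torch.nn as nn
from mamba_ssm.models.mixer_seq_simple import MixerModel

class MambaClassifierHeadModel(nn.Module):
    """Sketch: a Mamba backbone with a classification head instead of an LM head."""

    def __init__(self, d_model, n_layer, vocab_size, num_classes):
        super().__init__()
        # Newer mamba_ssm versions also take d_intermediate here
        self.backbone = MixerModel(d_model=d_model, n_layer=n_layer, vocab_size=vocab_size)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, input_ids):
        hidden_states = self.backbone(input_ids)  # (batch, seq_len, d_model)
        # Mamba is causal, so the final position has seen the whole sequence
        pooled = hidden_states[:, -1, :]
        return self.classifier(pooled)            # (batch, num_classes)
```

Last-token pooling mirrors how causal LMs are usually adapted for classification; mean pooling over positions is a reasonable alternative.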
Trying to do the same task here, but an error occurs:
```
~/anaconda3/envs/myenv/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py:136, in Mamba.forward(self, hidden_states, inference_params)
    132     return out
    134 # We do matmul and transpose BLH -> HBL at the same time
    135 xz = rearrange(
--> 136     self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
    137     "d (b l) -> b d l",
    138     l=seqlen,
    139 )
    140 if self.in_proj.bias is not None:
    141     xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x1 and 12x1920)
```
My input is shaped (128, 15, 12), and the Mamba model is configured as:

```python
model = Mamba(
    d_model=12,  # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
)
```

Any insights on this? Thanks.
The inputs to the Mamba block probably aren't formatted correctly. Check the example in the README and double-check all your shapes.
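For what it's worth, the mat1 shape in the error is telling: in_proj.weight has shape (2 * expand * d_model, d_model), so a (4x1) weight means the failing block was actually built with d_model=1, not 12. It's worth checking how the model is really instantiated. Here is a sanity check adapted from the README example, substituting the shapes from the post (assuming a CUDA device, which the mamba_ssm kernels require):

```python
import torch
from mamba_ssm import Mamba

# Shapes from the post above
batch, length, dim = 128, 15, 12
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    d_model=dim,  # d_model must match the last dimension of the input
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)  # input is (batch, length, d_model); output has the same shape
assert y.shape == x.shape
```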
That's weird; the same input data runs through LSTM/Transformer models successfully.
One way is to use pretrained models from Hugging Face together with the transformers library:
```python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

num_labels = 2  # number of target classes

# Mamba checkpoints were trained with the GPT-NeoX tokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m")

# Replace the language-modeling head with a classification head
model.lm_head = torch.nn.Linear(model.config.d_model, num_labels)
```
From here, you fine-tune the resulting model on your classification task. This approach was used with great success on Kaggle, where you can find more details: https://www.kaggle.com/competitions/llm-detect-ai-generated-text/discussion/470093
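For the fine-tuning step itself, a minimal sketch (these are assumptions, not part of the linked solution: last-token pooling for the classification logits, left padding with a pad token assigned since the GPT-NeoX tokenizer ships without one, and everything on CUDA as the mamba_ssm kernels require):

```python
import torch
import torch.nn.functional as F

tokenizer.pad_token = tokenizer.eos_token   # GPT-NeoX tokenizer has no pad token by default
tokenizer.padding_side = "left"             # so the last position holds a real token for every sequence
model = model.to("cuda")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Hypothetical toy batch; in practice iterate over your labeled dataset
texts = ["first example text", "second example text"]
labels = torch.tensor([0, 1]).to("cuda")

input_ids = tokenizer(texts, return_tensors="pt", padding=True).input_ids.to("cuda")
logits = model(input_ids).logits   # (batch, seq_len, num_labels) after the head swap
cls_logits = logits[:, -1, :]      # last-token pooling over the causal backbone
loss = F.cross_entropy(cls_logits, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```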
Hello. Can you please share some insights on how one can train a text classifier with Mamba?