state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Training Script #1

Open loretoparisi opened 10 months ago

loretoparisi commented 10 months ago

It would be worthwhile to provide a training script so that larger models (for instance, 7B or 13B) can be trained.
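When sizing a 7B or 13B configuration, a rough parameter count helps. The sketch below is an estimate, not the library's own accounting. It assumes the block hyperparameters are the defaults in `mamba_simple.py` (`expand=2`, `d_state=16`, `d_conv=4`, `dt_rank = ceil(d_model / 16)`, biases only on the depthwise conv and `dt_proj`), RMSNorm with weights only, and tied input/output embeddings. `mamba_param_count` is a hypothetical helper.

```python
import math


def mamba_param_count(d_model, n_layer, vocab_size, d_state=16, d_conv=4,
                      expand=2, pad_vocab_size_multiple=8, tie_embeddings=True):
    """Estimate parameters of a MambaLMHeadModel-style stack (assumed defaults)."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    per_layer = (
        d_model * 2 * d_inner                  # in_proj (x and z branches, no bias)
        + d_inner * d_conv + d_inner           # depthwise conv1d weight + bias
        + d_inner * (dt_rank + 2 * d_state)    # x_proj -> (dt, B, C), no bias
        + dt_rank * d_inner + d_inner          # dt_proj weight + bias
        + d_inner * d_state                    # A_log
        + d_inner                              # D (skip connection)
        + d_inner * d_model                    # out_proj, no bias
        + d_model                              # RMSNorm before the block
    )
    # The vocabulary is padded up to a multiple of pad_vocab_size_multiple.
    vocab = math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    embeddings = vocab * d_model * (1 if tie_embeddings else 2)
    return n_layer * per_layer + embeddings + d_model  # + final RMSNorm


# The released mamba-2.8b config: d_model=2560, n_layer=64, vocab_size=50277
print(f"{mamba_param_count(2560, 64, 50277) / 1e9:.2f}B")  # prints 2.77B
```

Sweeping candidate `d_model`/`n_layer` pairs through this function is one way to land near a target size; any such config should still be checked by instantiating the model and summing `p.numel()` over its parameters.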

tridao commented 10 months ago

You can use whichever training script or library you like, e.g. Megatron, DeepSpeed, Lightning, HF Accelerate, etc. You just have to replace the model definition.

Examples: Lightning has lit-gpt (https://github.com/Lightning-AI/lit-gpt). FlashAttention has training code where you can swap in the model: https://github.com/Dao-AILab/flash-attention/tree/main/training
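In a plain PyTorch loop, "replacing the model definition" might look like the sketch below. It assumes only that the model, like `MambaLMHeadModel`, takes `input_ids` and returns an object with a `.logits` field of shape (batch, seq_len, vocab). `TinyCausalLM` is a hypothetical stand-in so the example runs on CPU without the Mamba CUDA kernels.

```python
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in that mimics MambaLMHeadModel's interface (input_ids -> .logits).
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])


class TinyCausalLM(nn.Module):
    def __init__(self, vocab_size=32, d_model=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids):
        return CausalLMOutput(logits=self.lm_head(self.embedding(input_ids)))


def train_step(model, optimizer, input_ids):
    """One next-token-prediction step for any model that returns `.logits`."""
    model.train()
    logits = model(input_ids).logits
    # Position t predicts token t+1: drop the last logit and the first target.
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        input_ids[:, 1:].reshape(-1),
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = TinyCausalLM()  # on a CUDA machine: MambaLMHeadModel(MambaConfig(...))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
batch = torch.randint(0, 32, (4, 12))
losses = [train_step(model, optimizer, batch) for _ in range(20)]
```

The training loop itself never touches model internals, which is why any framework's loop can be reused once the model is swapped; the real model and batch would additionally need to be moved to the GPU.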

geronimi73 commented 10 months ago

It doesn't seem to be so straightforward with the HF Trainer, quite literally:

  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2725, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2748, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 680, in forward
    return model_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 668, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
TypeError: MambaLMHeadModel.forward() got an unexpected keyword argument 'labels'

No labels argument in forward()?

It would be very nice if you could provide a simple, minimal example of how to use the models with the HF Trainer. Thank you!
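One way around the missing `labels` argument, sketched here under assumptions rather than taken from this thread, is to wrap the model so that its `forward()` accepts the keys the HF Trainer passes (`input_ids`, `labels`, `attention_mask`) and returns a dict with a `loss` entry, which `Trainer.compute_loss` reads as `outputs["loss"]`. `CausalLMWithLabels` and `_Stub` are hypothetical names; the stub only mimics `MambaLMHeadModel`'s `.logits` output so the example runs on CPU.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalLMWithLabels(nn.Module):
    """Hypothetical wrapper adding HF-style `labels` handling to a logits-only model."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, input_ids, labels=None, attention_mask=None):
        # attention_mask is accepted because HF collators usually emit it,
        # but it is ignored: the Mamba forward pass takes no mask argument.
        logits = self.base_model(input_ids).logits
        outputs = {"logits": logits}
        if labels is not None:
            # HF convention: labels are aligned with input_ids, so shift here.
            outputs["loss"] = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1).to(logits.device),
                ignore_index=-100,  # positions labelled -100 do not count
            )
        return outputs


class _Stub(nn.Module):
    """Hypothetical stand-in exposing `.logits`, like MambaLMHeadModel."""

    def __init__(self, vocab=10, dim=8):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, input_ids):
        return SimpleNamespace(logits=self.head(self.emb(input_ids)))


torch.manual_seed(0)
model = CausalLMWithLabels(_Stub())
ids = torch.tensor([[1, 2, 3, 0]])
labels = torch.tensor([[1, 2, 3, -100]])  # last (pad) position is ignored
out = model(input_ids=ids, labels=labels, attention_mask=torch.ones_like(ids))
```

With a real model, `CausalLMWithLabels(mamba_model)` would be passed as the `model` argument of `Trainer`. Subclassing `Trainer` and overriding `compute_loss` is an alternative that leaves the model itself untouched.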

justusmattern27 commented 10 months ago

We've managed to train Mamba by modifying the Hugging Face Trainer class. Here is our implementation; we were actually able to train a chat model that seems to perform quite well.

binxuan commented 10 months ago


Cool, nice work! Are you using fp32 for this finetuning work?

Calvinnncy97 commented 9 months ago

Hmm... It doesn't seem to work out of the box with lit-gpt.

Minimal example:

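import lightning as L
from lightning.fabric.strategies import DeepSpeedStrategy
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.hf import load_config_hf
from mamba_ssm.models.config_mamba import MambaConfig

# DeepSpeed ZeRO stage 3 (the failing case; stage 2 is reported to work)
fabric = L.Fabric(strategy=DeepSpeedStrategy(stage=3))
fabric.launch()

# Build the model inside Fabric's init context (DeepSpeed partitions parameters here)
with fabric.init_module(empty_init=isinstance(fabric.strategy, DeepSpeedStrategy)):
    config = load_config_hf('state-spaces/mamba-2.8b')  # dict loaded from config.json
    model = MambaLMHeadModel(MambaConfig(**config))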

This gives the following error:

Traceback (most recent call last):
  File "/home/me/lit-gpt/pretrain/mamba.py", line 782, in <module>
    setup(run_config)
  File "/home/me/lit-gpt/pretrain/mamba.py", line 306, in setup
    main(fabric, run_config)
  File "/home/me/lit-gpt/pretrain/mamba.py", line 342, in main
    model = MambaLMHeadModel(MambaConfig(**model_config))
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 199, in __init__
    self.backbone = MixerModel(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 118, in __init__
    [
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 119, in <listcomp>
    create_block(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 42, in create_block
    block = Block(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 316, in __init__
    self.mixer = mixer_cls(dim)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 99, in __init__
    self.dt_proj.bias.copy_(inv_dt)
RuntimeError: The size of tensor a (0) must match the size of tensor b (5120) at non-singleton dimension 0

thistleknot commented 9 months ago

If that's true, how the heck am I to pass attention?
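Assuming the question is about the attention mask that HF collators pass alongside `input_ids`: `MambaLMHeadModel.forward()` has no mask argument. A common workaround for causal models, offered as an assumption rather than something confirmed in this thread, is right-padding. Because the model is causal, pad tokens appended after the real tokens cannot change the logits at earlier positions, so it is enough to exclude them from the loss with `-100`, the default `ignore_index` of PyTorch's cross-entropy loss. `pad_right_with_labels` is a hypothetical collator helper.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_right_with_labels(sequences, pad_id):
    """Right-pad token-id lists and build labels that ignore the padding.

    In a causal model, trailing pad tokens cannot influence earlier
    positions, so masking their labels replaces an attention mask.
    """
    max_len = max(len(seq) for seq in sequences)
    input_ids, labels = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(list(seq) + [pad_id] * n_pad)
        labels.append(list(seq) + [IGNORE_INDEX] * n_pad)
    return {"input_ids": input_ids, "labels": labels}


batch = pad_right_with_labels([[5, 6, 7], [8, 9]], pad_id=0)
# batch["input_ids"] -> [[5, 6, 7], [8, 9, 0]]
# batch["labels"]    -> [[5, 6, 7], [8, 9, -100]]
```

Left-padding would not work with this trick, since the pad tokens would then precede the real tokens and flow into their hidden states.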

Calvinnncy97 commented 9 months ago

ZeRO stage 2 works, but stage 3 does not. I don't have a fix at the moment. The problem is this line: https://github.com/state-spaces/mamba/blob/eb2f7a520dd5e2949b7ae1c3ef44f6cb99faef5c/mamba_ssm/modules/mamba_simple.py#L98
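The traceback shows each `__init__` running inside DeepSpeed's `partition_parameters.py` wrapper, i.e. ZeRO-3 partitions every parameter as soon as its module is built. Inside `Mamba.__init__`, the local `dt_proj.bias` is then an empty placeholder (size 0), while `inv_dt` has `d_inner = 2 * 2560 = 5120` elements for mamba-2.8b. The sketch below reproduces that mismatch in plain PyTorch: an empty tensor stands in for the partitioned parameter, and `dt_bias_init` (a hypothetical helper) follows the inverse-softplus dt initialization with defaults assumed to be `dt_min=0.001`, `dt_max=0.1`, `dt_init_floor=1e-4`. Untested directions for an actual fix include building the model outside `fabric.init_module`, so parameters are fully materialized before partitioning, or wrapping the copy in `deepspeed.zero.GatheredParameters(..., modifier_rank=0)`.

```python
import math

import torch
import torch.nn.functional as F


def dt_bias_init(d_inner, dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
    # Sample dt log-uniformly in [dt_min, dt_max], then invert softplus so
    # that softplus(bias) == dt at initialization.
    dt = torch.exp(
        torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
    ).clamp(min=dt_init_floor)
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    return dt, inv_dt


dt, inv_dt = dt_bias_init(5120)

# A fully materialized bias (no ZeRO, or ZeRO stage 2) accepts the copy.
full_bias = torch.empty(5120)
with torch.no_grad():
    full_bias.copy_(inv_dt)

# Under ZeRO-3 init the local parameter is an empty placeholder, so the same
# copy fails with the size-0 vs size-5120 mismatch seen in the traceback.
partitioned_bias = torch.empty(0)
try:
    with torch.no_grad():
        partitioned_bias.copy_(inv_dt)
    error = None
except RuntimeError as exc:
    error = str(exc)
```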