loretoparisi opened 10 months ago
You can use whichever training script / library you'd like, e.g. Megatron, DeepSpeed, Lightning, HF Accelerate, etc. You just have to replace the model definition.

Examples:
- Lightning has lit-gpt: https://github.com/Lightning-AI/lit-gpt
- FlashAttention has training code where you can swap in the model: https://github.com/Dao-AILab/flash-attention/tree/main/training
This does not seem to be so straightforward with the HF `Trainer`, quite literally:

```
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1555, in train
  return inner_training_loop(
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1860, in _inner_training_loop
  tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2725, in training_step
  loss = self.compute_loss(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2748, in compute_loss
  outputs = model(**inputs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 680, in forward
  return model_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 668, in __call__
  return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
  return func(*args, **kwargs)
TypeError: MambaLMHeadModel.forward() got an unexpected keyword argument 'labels'
```

There is no `labels` argument in `forward()`?
It would be very nice if you could provide a simple, minimal example of how to use the models with the HF Trainer. Thank you!
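Since `MambaLMHeadModel.forward()` does not accept `labels`, one possible workaround is to subclass `Trainer` and compute the shifted next-token cross-entropy yourself, which is what HF does internally when `labels` are passed. This is a sketch, not verified against the actual model: `MambaTrainer` is a hypothetical name, and it assumes the model's output exposes a `.logits` field.

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits, input_ids):
    # Shift so that position t predicts token t+1, then flatten for cross-entropy.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )

# Hypothetical Trainer subclass (assumption: MambaLMHeadModel returns an object
# with a .logits field; only pass input_ids, since forward() takes no labels):
#
# class MambaTrainer(Trainer):
#     def compute_loss(self, model, inputs, return_outputs=False):
#         input_ids = inputs["input_ids"]
#         logits = model(input_ids).logits
#         loss = causal_lm_loss(logits, input_ids)
#         return (loss, logits) if return_outputs else loss

# Sanity check of the loss helper on random tensors:
logits = torch.randn(2, 8, 100)            # (batch, seq_len, vocab)
input_ids = torch.randint(0, 100, (2, 8))  # (batch, seq_len)
loss = causal_lm_loss(logits, input_ids)
```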
We've managed to train mamba by modifying the Huggingface Trainer class. Here is our implementation, we were actually able to train a chat model that seems to perform quite well.
Cool, nice work! Are you using fp32 for this finetuning work?
Hmm... it doesn't seem to work out of the box with lit-gpt.

Minimal example:

```python
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.hf import load_config_hf
from mamba_ssm.models.config_mamba import MambaConfig

with fabric.init_module(
    empty_init=isinstance(fabric.strategy, DeepSpeedStrategy)
):
    config = load_config_hf('state-spaces/mamba-2.8b')
    model = MambaLMHeadModel(MambaConfig(**config))
```
This gives the following error:

```
Traceback (most recent call last):
  File "/home/me/lit-gpt/pretrain/mamba.py", line 782, in <module>
    setup(run_config)
  File "/home/me/lit-gpt/pretrain/mamba.py", line 306, in setup
    main(fabric, run_config)
  File "/home/me/lit-gpt/pretrain/mamba.py", line 342, in main
    model = MambaLMHeadModel(MambaConfig(**model_config))
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 199, in __init__
    self.backbone = MixerModel(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 118, in __init__
    [
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 119, in <listcomp>
    create_block(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 42, in create_block
    block = Block(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 316, in __init__
    self.mixer = mixer_cls(dim)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 99, in __init__
    self.dt_proj.bias.copy_(inv_dt)
RuntimeError: The size of tensor a (0) must match the size of tensor b (5120) at non-singleton dimension 0
```
If that's true, how the heck am I supposed to pass attention?
Stage 2 works, but not stage 3. I don't have a fix at the moment. The problem is this line: https://github.com/state-spaces/mamba/blob/eb2f7a520dd5e2949b7ae1c3ef44f6cb99faef5c/mamba_ssm/modules/mamba_simple.py#L98
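For context: under ZeRO-3 (`zero.Init`), parameters are partitioned as soon as the module is constructed, so inside `__init__` a parameter can have zero elements on a given rank, and the in-place `copy_` fails exactly as in the traceback above. A minimal reproduction with plain tensors (the `GatheredParameters` workaround in the comment is an assumption, not verified against mamba_ssm):

```python
import torch

# Under DeepSpeed ZeRO-3, dt_proj.bias can already be an empty (partitioned)
# tensor inside __init__; this stands in for that situation.
partitioned_bias = torch.empty(0)  # partitioned parameter shard on this rank
inv_dt = torch.rand(5120)          # the initializer value being copied in

try:
    partitioned_bias.copy_(inv_dt)
except RuntimeError as e:
    print(e)  # same size mismatch (0 vs 5120) as in the traceback

# A possible fix (assumption): gather the parameter before writing to it, e.g.
#   with deepspeed.zero.GatheredParameters(self.dt_proj.bias, modifier_rank=0):
#       self.dt_proj.bias.copy_(inv_dt)
```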
It would be worthwhile to provide a training script, in order to train larger models (for instance 7B or 13B).