DanFosing opened this issue 1 month ago
I experimented with my config a bit, and it seems that setting head_dim to 64 and increasing num_heads to 32, while also setting n_groups to 1, fixes the issue. I've currently managed to train up to 200M tokens without any problems (the loss is a bit higher than mamba-1, by about 0.1, but there are no grad-norm issues). A friend of mine also fixed it with a slightly different configuration (head_dim set to 128, num_heads set to 16), again with n_groups set to 1. Is this behavior expected, i.e. should n_groups be set higher than 1 only for bigger models, or is there an actual problem with it?
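One thing worth noting about the two working configs above: in Mamba-2 the inner projection width factors as num_heads * head_dim, so both variants describe the same inner width and only change how it is split into heads. A minimal sketch of that arithmetic (the inner width of 2048 is an assumption inferred from the two configs, not stated in the issue):

```python
# Sketch: Mamba-2's inner dimension factors as num_heads * head_dim.
# Both reported "working" configs keep the same inner width (assumed 2048 here),
# differing only in the head split.
def inner_dim(num_heads: int, head_dim: int) -> int:
    return num_heads * head_dim

assert inner_dim(32, 64) == inner_dim(16, 128)  # same inner width, different split
print(inner_dim(32, 64))  # → 2048
```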
700M tokens in, and it indeed seems that setting n_groups to 1 fixes the issue. The loss is still worse by 1.8%, but I guess that may be normal.
I think ngroups was meant to be set based on the number of GPUs used during training, see https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/modules/mamba2.py#L82
Mistral very likely used a multi-GPU setting in their case. Otherwise, the original mamba2 models all use ngroups=1 (from 130m up to 2.7b). I'm a bit surprised that it causes exploding gradients, though; I'd have to dig deeper into why exactly. Does it use the fast or slow path in your case?
Maybe it would work if I set n_groups to the number of GPUs; I haven't actually tried that. How can I check which path it uses?
You would only need to care about it being a multiple of the number of GPUs. But there is also a difference in the current transformers implementation: I would guess that each GPU gets ngroups = config.ngroups, which means your effective ngroups = config.ngroups * num_gpus (though I'm not too familiar with how transformers handles multi-GPU).
In the original repo you would distribute config.ngroups across the GPUs, so that the effective ngroups remains the same as in the config.
In short, what I want to say is: your effective ngroups is very high at 8*2=16. Have you tried ngroups=4 to achieve an effective ngroups of 8? Keep in mind that if you load the model in a different multi-GPU setting, the ngroups config has to be adjusted (e.g. a single GPU would need ngroups=8 again).
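The effective-ngroups arithmetic described above can be sketched as follows (this is my reading of the discussion, not verified against the transformers source; the function name is illustrative):

```python
# Sketch of the effective-ngroups arithmetic from the discussion above:
# in transformers (as described here), each GPU gets config.ngroups,
# so the effective value multiplies with the number of GPUs.
def effective_ngroups(config_ngroups: int, num_gpus: int) -> int:
    return config_ngroups * num_gpus

print(effective_ngroups(8, 2))  # 16: the setup that showed exploding gradients
print(effective_ngroups(4, 2))  # 8: the suggested alternative
```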
You would get a warning during your run if the slow path were used, reminding you to install the original mamba_ssm package. If there was no such warning, you were probably on the fast path.
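A rough way to check this up front is whether the fused kernels import cleanly, since the fast path depends on them being installed. This is a hedged sketch of the idea, not the exact guard transformers uses:

```python
# Hedged sketch: the fast path requires the fused CUDA kernels from
# mamba_ssm and causal_conv1d; if they don't import, the model falls back
# to the slow path (and transformers emits a warning). The exact check in
# the modeling code may differ from this illustration.
def fast_path_available() -> bool:
    try:
        import mamba_ssm       # fused selective-scan kernels
        import causal_conv1d   # fused causal-conv kernels
        return True
    except ImportError:
        return False

print("fast path" if fast_path_available() else "slow path (expect a warning)")
```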
I tried using n_groups of 4, and I also tried running the training on only one GPU with n_groups=4, and the gradients still exploded. And yes, I'm using the fast path. What's also really interesting is that I'm getting a warning that I'm using ddp_find_unused_parameters even though it didn't find any unused parameters, but I suppose it's just a false positive.
I guess there is not much to do, since this is still active research, so some settings may be undesirable, just as you have discovered with ngroups. Maybe the original authors can answer you on the role of ngroups in more detail.
Can't help with the DDP warning as that's PyTorch territory ^^ but it does have some buggy stuff from time to time.
Also, have you tried clipping gradients? I think the original mamba(2) training settings employed it, so that might stabilize training in general even with higher ngroups (but that's just a guess).
Edit: never mind, gradient clipping seems to be on by default in the Trainer.
I'm not sure if that's what you mean, but as you can see in the code snippets I uploaded, I'm using max_grad_norm=1.0. It doesn't really help, though: after some time grad_norm becomes infinity, so even with clipping the model trains worse, since grad_norm is always pinned at its maximum value.
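For intuition on why clipping can't rescue an infinite gradient norm: norm clipping rescales gradients by max_norm / total_norm, and once total_norm is inf that scale collapses to 0, producing zeros and NaNs instead of usefully bounded gradients. A toy sketch (not the Trainer's actual implementation):

```python
import math

# Toy sketch of gradient-norm clipping (not the Trainer's actual code):
# gradients are rescaled by max_norm / total_norm when the norm is too large.
def clip_by_norm(grads, max_norm):
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm  # collapses to 0.0 when total_norm is inf
        grads = [g * scale for g in grads]
    return grads

print(clip_by_norm([3.0, 4.0], 1.0))       # rescaled to unit norm
print(clip_by_norm([math.inf, 4.0], 1.0))  # degenerate: inf * 0.0 gives nan
```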
This is how grad_norm looks in my hybrid architecture after setting n_groups to 1 (purple); the light blue curve is exactly the same architecture but with mamba-1 (of course I changed the initialization so it works well with mamba-2). It's not too stable, but it doesn't go over the gradient clipping value, so it seems fine (and it's getting more stable with time; each step is 1M tokens).
Yup, seems like ngroups=1 is viable in that case. Sorry about the confusion regarding gradient clipping; I didn't know it was the default in the Trainer.
And this is how it looked with n_groups set to 8
I'm wondering if it's somehow expected or if there is an error in the implementation
Understandable, but I think this is a question that belongs to the original repo rather than transformers.
Indeed! And thanks @vasqu for taking the lead and answering 🤗
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
transformers version: 4.44.0

Who can help?
@ArthurZucker @muellerz @SunMarc
Information

Tasks
examples folder (such as GLUE/SQuAD, ...)

Reproduction
Load any training/eval dataset you want, tokenize it, and split it into 2k parts.
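A minimal sketch of that preprocessing step, assuming "2k parts" means fixed 2048-token blocks (the function name and block size are illustrative, not from the issue):

```python
# Hypothetical sketch of the "split into 2k parts" step: pack token ids into
# fixed-size blocks, dropping the trailing remainder so every block is full.
def chunk_ids(ids, block_size=2048):
    n_full = len(ids) // block_size
    return [ids[i * block_size:(i + 1) * block_size] for i in range(n_full)]

blocks = chunk_ids(list(range(5000)))
print(len(blocks), len(blocks[0]))  # 2 blocks of 2048 tokens each
```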
Use the model config from mamba-2 configuration with those settings (the rest is default):
Set training arguments and trainer settings like that:
```python
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=eval_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
```