huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Mamba-2 Exploding Gradients #32570

Open DanFosing opened 1 month ago

DanFosing commented 1 month ago

### System Info

### Who can help?

@ArthurZucker @muellerz @SunMarc

### Information

### Tasks

### Reproduction

  1. Load any training/eval dataset you want, tokenize it, and split it into 2k-token chunks (a minimal sketch follows).
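
For example, something along these lines; this is only a sketch of the preprocessing step, and the dataset and tokenizer names are placeholders rather than what was actually used:

    from datasets import load_dataset
    from transformers import AutoTokenizer

    block_size = 2048  # "2k" chunks

    tokenizer = AutoTokenizer.from_pretrained("some/tokenizer")  # placeholder
    raw = load_dataset("some/dataset", split="train")            # placeholder

    def tokenize(batch):
        return tokenizer(batch["text"])

    def group_texts(batch):
        # Concatenate all token ids and cut them into fixed-size 2048-token blocks.
        concatenated = sum(batch["input_ids"], [])
        total = (len(concatenated) // block_size) * block_size
        return {"input_ids": [concatenated[i : i + block_size] for i in range(0, total, block_size)]}

    tokenized = raw.map(tokenize, batched=True, remove_columns=raw.column_names)
    dataset = tokenized.map(group_texts, batched=True, remove_columns=tokenized.column_names)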

  2. Use the Mamba-2 model configuration with these settings, the rest left at defaults (a sketch of how the model would be built from it follows the config block):

    # config_class here is the Mamba-2 config class, i.e. transformers.Mamba2Config
    config = config_class(
        vocab_size=32000,
        hidden_size=1024,
        num_hidden_layers=12,
        head_dim=128,
        expand=2,
        num_heads=16,
        n_groups=8,
        state_size=128,
        use_cache=False,
        is_training=True,
        residual_in_fp32=True,
        norm_before_gate=False,
        rms_norm=True,  # it can be False; that makes the issue appear a bit later, but it still appears
    )
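
The report doesn't show how the model itself is built from this config; presumably something like the following, where `Mamba2ForCausalLM` is the transformers model class and the from-scratch initialization is an assumption:

    from transformers import Mamba2ForCausalLM

    # Assumption: the model is trained from scratch, instantiated directly from
    # the config above (config_class would then be transformers.Mamba2Config).
    model = Mamba2ForCausalLM(config)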
  3. Set training arguments and Trainer settings like this:

    
    training_args = TrainingArguments(
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        gradient_checkpointing=False,
        gradient_accumulation_steps=16,  # set it so the total batch size is 1M tokens
        load_best_model_at_end=False,
        num_train_epochs=1,
        eval_strategy="steps",
        learning_rate=4e-4,
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        bf16_full_eval=torch.cuda.is_bf16_supported(),
        fp16_full_eval=not torch.cuda.is_bf16_supported(),
        logging_steps=10,
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_epsilon=1e-7,
        optim="adamw_torch",
        save_total_limit=4000,
        eval_steps=40,
        save_steps=500,
        save_strategy="steps",
        weight_decay=0.1,
        max_grad_norm=1.0,
        seed=seed,
        lr_scheduler_type="cosine_with_min_lr",
        warmup_ratio=0.01,
        lr_scheduler_kwargs={"min_lr_rate": 0.1},
        push_to_hub=True,
        hub_private_repo=True,
        output_dir="mamba2training",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )



  4. Train using `accelerate launch --mixed_precision bf16 train.py` (if you don't launch it that way, it crashes with a "'NoneType' object is not a mapping" error; I believe this is related to Mamba-2 and Triton not supporting multi-threading).

  5. After around 100M tokens of training (100 steps), grad_norm reaches the billions and eventually infinity, while the loss starts rising.

### Expected behavior

I believe the script shouldn't require `accelerate launch`, and considering the model is pretty small, there shouldn't be any problem with exploding gradients. The higher the learning rate, the earlier the problem appears, but from my observations it appears even with a learning rate as low as 7e-5.
DanFosing commented 1 month ago

I experimented with my config a bit, and it seems that setting head_dim to 64 and increasing num_heads to 32, while also setting n_groups to 1, fixes the issue. I've currently trained up to 200M tokens without any problems (the loss is a bit higher than mamba-1, by about 0.1, but there are no grad-norm issues). A friend of mine also fixed it with a slightly different configuration (head_dim set to 128, num_heads set to 16), again by setting n_groups to 1. Is this behavior expected, with n_groups meant to be set higher than 1 only for bigger models, or is there an actual problem with it?
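
For reference, here is roughly what that changed config looks like, using the same base settings as in the reproduction steps (a sketch of the settings described above, not a verified fix):

    # Settings that avoided the exploding gradients in my runs (sketch):
    config = config_class(
        vocab_size=32000,
        hidden_size=1024,
        num_hidden_layers=12,
        head_dim=64,    # was 128
        num_heads=32,   # was 16
        n_groups=1,     # was 8 -- this seems to be the decisive change
        expand=2,
        state_size=128,
        use_cache=False,
        is_training=True,
        residual_in_fp32=True,
        norm_before_gate=False,
        rms_norm=True,
    )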

DanFosing commented 1 month ago

700M tokens in, and it seems that setting n_groups to 1 does indeed fix the issue. The loss is still worse by 1.8%, but I guess that may be normal.

vasqu commented 1 month ago

I think ngroups was meant to be set based on the number of GPUs during training; see https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/modules/mamba2.py#L82

Mistral very likely used a multi-GPU setting in their case. Otherwise, the original mamba2 models all use ngroups=1 (from 130m up to 2.7b). I'm a bit surprised that it causes exploding gradients, though; I would have to dig deeper into why exactly. Does it use the fast or slow path in your case?

DanFosing commented 1 month ago

Maybe it would work if I set n_groups to the number of GPUs; I haven't actually tried that. How can I check which path it uses?

vasqu commented 1 month ago

You would only need to care about it being a multiple of the number of GPUs. But there is also a difference in the current transformers implementation: I would guess that each GPU gets ngroups = config.n_groups, which means that your effective ngroups = config.n_groups * num_gpus (but I'm not too familiar with how transformers handles multi-GPU).

In the original repo you would distribute config.n_groups across the GPUs, so that the effective ngroups remains the same as in the config.

In short, what I want to say is: your effective ngroups is very high at 8 * 2 = 16. Have you tried n_groups=4 to achieve an effective ngroups of 8? Keep in mind that the n_groups config then has to be adjusted to the GPU setting you load in (e.g. on a single GPU you would need n_groups=8 again).
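
To make the arithmetic explicit (a sketch based on my assumption above about how transformers handles this, not verified against the code):

    # Assumption: each GPU gets the full config.n_groups in this setup, so the
    # effective value scales with the GPU count; in the original mamba_ssm repo
    # the configured value is split across GPUs instead.
    config_n_groups = 8
    num_gpus = 2

    effective_ngroups_here = config_n_groups * num_gpus  # 16 in the run above
    effective_ngroups_original_repo = config_n_groups    # stays 8 (4 per GPU)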

You would get a warning during your run, reminding you to install the original mamba_ssm package, if the slow path were being used. If there was no such warning, you were probably on the fast path.
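
If you want to check programmatically, one rough way is to test whether the optional kernel packages import at all (a sketch; `mamba_ssm` is the package the warning mentions, and `causal_conv1d` is assumed to be needed as well):

    # Rough check (sketch): the fast path relies on the fused kernels from these
    # packages; if the imports fail, the slow path is used and the warning above
    # is printed.
    def fast_path_packages_available() -> bool:
        try:
            import mamba_ssm      # fused SSD kernels
            import causal_conv1d  # fused causal conv1d kernel
        except ImportError:
            return False
        return True

    print("fast-path kernel packages importable:", fast_path_packages_available())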

DanFosing commented 1 month ago

I tried n_groups=4, and I also tried running the training on only one GPU with n_groups=4, and the gradients still exploded. And yeah, I'm using the fast path. What's also really interesting is that I'm getting a warning about ddp_find_unused_parameters being set even though it didn't find any unused parameters, but I suppose that's just a false positive.

vasqu commented 1 month ago

I guess there is not much to do, since this is still active research and some settings may simply be undesirable, just as you discovered with ngroups. Maybe the original authors can tell you more about the role of ngroups.

Can't help with the DDP warning as that's PyTorch territory ^^ but it does have some buggy behavior from time to time.

vasqu commented 1 month ago

Also, have you tried clipping gradients? I think the original mamba(2) training settings employed it, so that might stabilize training in general, even with higher ngroups (but that's just a guess).

Edit: never mind, clipping gradients seems to be the default in the Trainer.

DanFosing commented 1 month ago

I'm not sure if that's what you mean, but as you can see in the code snippets I posted, I'm using max_grad_norm=1.0. It doesn't really help, though: after some time grad_norm is infinity, so even with clipping the training gets worse, since grad_norm is always at its maximum value.

DanFosing commented 1 month ago

This is how grad_norm looks in my hybrid architecture after setting n_groups to 1 (purple); the light blue run is exactly the same architecture but with mamba-1 (of course I changed the initialization so it works well with mamba-2).

[attached chart: grad_norm over training steps for the two runs]

It's not too stable, but it never goes over the gradient-clipping value, so it seems fine (and it's getting more stable over time; each step is 1M tokens).

vasqu commented 1 month ago

Yup, seems like ngroups = 1 is viable in that case. Sorry about the confusion regarding gradient clipping; I didn't know it was the default in the Trainer.

DanFosing commented 1 month ago

And this is how it looked with n_groups set to 8:

[attached chart: grad_norm exploding with n_groups=8]

DanFosing commented 1 month ago

I'm wondering if it's somehow expected or if there is an error in the implementation

vasqu commented 1 month ago

Understandable, but I think this question belongs in the original repo rather than in transformers.

ArthurZucker commented 3 weeks ago

Indeed! And thanks @vasqu for taking the lead and answering 🤗

github-actions[bot] commented 22 hours ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.