huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Mistral loss instability #26498

Closed teknium1 closed 10 months ago

teknium1 commented 1 year ago

System Info

Hello, I've been working with dhokas, who finetuned Mistral's official instruct model. I have been trying to finetune Mistral on several datasets over dozens of ablations. Training this model with transformers shows severe loss instability that never seems to appear in his training runs, which do not use the HF Trainer.

I am opening this so we can get to the bottom of this. Here are some of my runs using axolotl with some datasets.

With hermes 2.0 dataset (unpublished): https://wandb.ai/teknium1/hermes2.0-mistral-7b?workspace=user-teknium1

With Teknium/GPT4-LLM-CLEANED dataset https://wandb.ai/teknium1/gpt4llm-mistral-7b

With a 5-sequences run to ensure loss goes to 0 (that memorization is occurring): https://wandb.ai/teknium1/5seq-mistral-7b?workspace=user-teknium1

With OpenHermes dataset teknium1/openhermes: https://wandb.ai/teknium1/hermes-mistral-7b

As can be seen, the loss curves across all these ablations are erratic, and the runs generally produce bad results no matter which hyperparameters are changed.

A Mistral dev who worked with me trained Mistral on the GPT4-LLM-Cleaned dataset and got this result: image

@younesbelkada @muellerz

Reproduction

Train Mistral on any of the above datasets with Mistral's own finetuning hyperparameters, as reported in Mistral's Discord, and watch the loss fail to converge.

Expected behavior

A smooth or downward trajectory for the loss.

younesbelkada commented 1 year ago

Thanks for the heads up. I realised a few days ago that this was wrong, and I closed the PR accordingly, as you can see from: https://huggingface.co/mistralai/Mistral-7B-v0.1/discussions/35

vince62s commented 1 year ago

Did you see my other comment above wrt inference and rotating cache?

younesbelkada commented 1 year ago

I have the impression that in HF you implemented only the sliding windows attention by playing only on the attention mask and ONLY at training time, which means that at inference, the full length is taken into account, am I correct ?

If you use the vanilla HF attention, yes, that is the case. We did not implement the rotating buffer cache mechanism, as it requires a significant refactor.
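
To make the "sliding window via the attention mask" idea concrete, here is a minimal pure-Python sketch. It is an illustration only, not the Transformers code; the function name is hypothetical, and counting the current token as part of the window is an assumed convention.

```python
def sliding_window_causal_mask(seq_len: int, window: int) -> list:
    """mask[i][j] is True when query position i may attend to key position j.

    Assumption: the window includes the current token, so each query sees
    at most `window` keys (itself plus the `window - 1` before it).
    """
    return [
        [(j <= i) and (i - j < window) for j in range(seq_len)]
        for i in range(seq_len)
    ]


mask = sliding_window_causal_mask(seq_len=6, window=3)
for row in mask:
    # 'x' marks an allowed (query, key) pair, '.' a masked one.
    print("".join("x" if allowed else "." for allowed in row))
```

The mask only changes which scores are kept. It does not by itself shrink the key/value cache, which is the gap the question above points at.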

However, we tried to mimic the rotating buffer cache mechanism, only for FA-2 models with padding_side=left, by shifting the cache and slicing out the previous tokens when generating the next token. See my benchmarks here for more details: https://github.com/huggingface/transformers/pull/26464#issuecomment-1743273513

vince62s commented 1 year ago

ok I get it, here: https://github.com/huggingface/transformers/pull/26464/files#diff-fa1653b47666859672060712644a8c40b2e61eb1b79c06a21f9b94569217ed43R372-R393 Anyway it requires some hardware to support seqlen > 4096 ....

younesbelkada commented 1 year ago

yes exactly

Anyway it requires some hardware to support seqlen > 4096 ....

No, you can scale to very large sequence lengths, as the cache will always hold 4096 tokens, similar to the rotating buffer cache in the original Mistral repository.

Per my understanding (cc @timlacroix please correct me if I am wrong), since we always use absolute positional embedding, the model is able to keep the whole context even if we go beyond 4096 tokens. If one feeds the model a very large context (>4096) directly on the first iteration, you will indeed need enough compute, but since the FA module will use sliding window attention, it should be quite memory efficient. Slicing the cache afterwards is not a problem, since the model has already computed attention scores based on the entire context on the first iteration, so the information is not lost. Batched generation is slightly more complex: we don't follow exactly the same procedure as Mistral's rotating buffer cache, and instead slice out the first tokens of the cache after the first iteration. But with BS=1 you should get pretty decent performance. If you have hardware that supports FlashAttention 2, you can try to generate a very large number of tokens without any major issue, I believe.
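
For readers unfamiliar with the rotating buffer cache being compared against, the sketch below shows the general idea under stated assumptions: a fixed number of slots, with the entry for position `i` written to slot `i % window`. The class and its methods are hypothetical and store plain Python objects; this is not the Mistral reference code and not the slicing approach described above for Transformers.

```python
class RollingKVCache:
    """Fixed-size cache that overwrites the oldest entry once it is full."""

    def __init__(self, window: int) -> None:
        self.window = window
        self.slots = [None] * window
        self.seen = 0  # total number of positions appended so far

    def append(self, kv) -> None:
        # Position i always lands in slot i % window.
        self.slots[self.seen % self.window] = kv
        self.seen += 1

    def contents(self) -> list:
        """Return the cached entries in chronological order."""
        if self.seen <= self.window:
            return self.slots[: self.seen]
        start = self.seen % self.window  # slot holding the oldest entry
        return self.slots[start:] + self.slots[:start]


cache = RollingKVCache(window=4)
for position in range(10):
    cache.append(f"kv{position}")
print(cache.contents())
```

Memory stays bounded by `window` however many tokens are generated, which matches the claim above that the cache never holds more than 4096 entries.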

vince62s commented 1 year ago

hmm, the cache size is not the only limiting factor. You still need to forward the full sequence through the model, and FlashAttention 2 still runs over the full length, even if the mechanism makes the cost linear in length (and not quadratic).
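
The linear versus quadratic point can be checked with simple counting. The sketch below counts how many (query, key) pairs are scored in one causal forward pass, with and without a window. It is plain arithmetic, not a benchmark, and the function name is hypothetical.

```python
from typing import Optional


def attended_pairs(seq_len: int, window: Optional[int] = None) -> int:
    """Count (query, key) pairs scored in one causal forward pass."""
    total = 0
    for query_pos in range(seq_len):
        visible = query_pos + 1  # causal: keys 0..query_pos
        if window is not None:
            visible = min(visible, window)  # keep only the last `window` keys
        total += visible
    return total


WINDOW = 4096  # window size mentioned in the thread
for n in (4096, 8192, 16384):
    print(n, attended_pairs(n), attended_pairs(n, WINDOW))
```

Past the window size, each extra token adds a constant `WINDOW` pairs instead of an ever-growing number, but the total still grows with the prompt length, so a long first forward pass is never free.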

younesbelkada commented 1 year ago

But that's true in any case, right? On the first forward pass, if you pass a large context, you'll need to compute the attention scores on all tokens.

teknium1 commented 1 year ago

@younesbelkada @bdytx5 @vince62s @arthurmensch

Okay update on the issue.

image

The above image compares DeepSpeed ZeRO-2 with FSDP. ZeRO-2 is the run with the more stable trajectory. All other hyperparameters are the same. I feel like I tested with ZeRO-3 in the past and found the same U-shaped pattern as in the FSDP run, but I am not sure at the moment.

At the moment I don't know whether this is caused by axolotl's interactions with FSDP or by something in transformers, accelerate, or elsewhere. But this seems like an important development in figuring out what's going on. I'm not sure how much you guys can look into it, but I figured I'd place the info here in case it isn't axolotl's code.

edit: nevermind... image

However, it still looks far better than my loss curves on runs with much lower LRs than the one above (which uses 2.5e-5): image

teknium1 commented 1 year ago

Ok, I did a new, longer run with DeepSpeed ZeRO-2 vs FSDP, all else the same: image

Something about FSDP is making it converge more slowly (technically, the loss is not moving downward at all, but very slightly upward) with LR 4e-6.

teknium1 commented 1 year ago

Zero 3 and Zero 2 seem fine, just not FSDP. I will reference the issue in axolotl and pytorch repos

image

nps798 commented 1 year ago

image

For me, using transformers, the Trainer, and a custom dataset, with a batch size of 2 and gradient accumulation of 6, the training loss drops to 0.0 after a certain point and the eval loss becomes NaN. I am using torch_dtype=torch.float16.

I've seen someone suggest changing float16 to bfloat16?

younesbelkada commented 1 year ago

Hi @nps798, yes, I think using bfloat16 is preferable, to be on the safe side. Also, something strange I have noticed: if you train with padding tokens, make sure to set padding_side="right": https://gist.github.com/younesbelkada/9f7f75c94bdc1981c8ca5cc937d4a4da?permalink_comment_id=4636728#gistcomment-4636728
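
One reason bfloat16 is often the safer choice is dynamic range: float16 overflows just above 65504, while bfloat16 shares float32's exponent range at lower precision. The standard-library sketch below illustrates this. The helper names are hypothetical, and bfloat16 is approximated by truncating a float32 to its top 16 bits, whereas real conversions usually round to nearest.

```python
import struct


def fits_in_fp16(x: float) -> bool:
    """True if x can be encoded as an IEEE half-precision float without overflow."""
    try:
        struct.pack("<e", x)  # 'e' is the half-precision format code
        return True
    except OverflowError:
        return False


def truncate_to_bf16(x: float) -> float:
    """Approximate bfloat16 by keeping the top 16 bits of the float32 encoding."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (value,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return value


for x in (100.0, 65504.0, 70000.0, 1e6):
    print(x, fits_in_fp16(x), truncate_to_bf16(x))
```

Values that overflow in float16 stay finite in the bfloat16 approximation, at the cost of keeping only about two to three significant decimal digits.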

nps798 commented 1 year ago

Thanks for your reply. I'll give it a try soon.

BTW, I have just encountered another issue with my previous setup (float16, left padding, QLoRA). I checked the input batches around the failing step (I print every batch on each step) and found nothing weird or special. I then checked all the model's parameters with the following code:

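import torch

# Scan every parameter tensor for NaN or Inf values.
# `model` is the already-loaded model being trained.
for name, param in model.named_parameters():
    if torch.isnan(param).any():
        print(f'NaN value detected in model weights: {name}')
    if torch.isinf(param).any():
        print(f'Infinity value detected in model weights: {name}')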

Nothing was printed.

So... correct me if I am wrong: the following NaN does not come from a problematic dataset. It is related to some of the model's weights being too small or too large, so any dataset would produce the NaN, and it cannot be detected beforehand?

input[0] has nans
output has nans

Detected inf/nan during batch_number=54681
Last 21 forward frames:
abs min  abs max  metadata
                  base_model.model.model.layers.30.mlp.gate_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 1.85e+02 input[0]
5.96e-08 3.02e+01 output
                  base_model.model.model.layers.30.mlp.act_fn SiLUActivation
5.96e-08 3.02e+01 input[0]
0.00e+00 2.39e+01 output
                  base_model.model.model.layers.30.mlp.up_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 1.85e+02 input[0]
5.96e-08 2.36e+01 output
                  base_model.model.model.layers.30.mlp.down_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 3.70e+02 input[0]
0.00e+00 1.38e+02 output
                  base_model.model.model.layers.30.mlp MistralMLP
0.00e+00 1.85e+02 input[0]
0.00e+00 1.38e+02 output
                  base_model.model.model.layers.30 MistralDecoderLayer
0.00e+00 3.05e+02 input[0]
0.00e+00 1.67e+02 output[0]
0.00e+00 1.68e+01 output[1][0]
0.00e+00 8.28e+00 output[1][1]
                  base_model.model.model.layers.31.input_layernorm MistralRMSNorm
8.36e-01 8.75e+00 weight
0.00e+00 1.67e+02 input[0]
0.00e+00 9.58e+01 output
                  base_model.model.model.layers.31.self_attn.q_proj.lora_dropout.default Dropout
0.00e+00 9.58e+01 input[0]
0.00e+00 1.01e+02 output
                  base_model.model.model.layers.31.self_attn.q_proj.lora_A.default Linear
9.78e-08 1.07e-01 weight
0.00e+00 1.01e+02 input[0]
2.04e-03 8.38e+01 output
                  base_model.model.model.layers.31.self_attn.q_proj.lora_B.default Linear
1.98e-07 8.64e-02 weight
2.04e-03 8.38e+01 input[0]
2.06e-07 2.49e+01 output
                  base_model.model.model.layers.31.self_attn.q_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 9.58e+01 input[0]
0.00e+00 2.62e+01 output
                  base_model.model.model.layers.31.self_attn.k_proj.lora_dropout.default Dropout
0.00e+00 9.58e+01 input[0]
0.00e+00 1.01e+02 output
                  base_model.model.model.layers.31.self_attn.k_proj.lora_A.default Linear
2.39e-07 7.29e-02 weight
0.00e+00 1.01e+02 input[0]
6.44e-05 5.60e+01 output
                  base_model.model.model.layers.31.self_attn.k_proj.lora_B.default Linear
3.00e-07 6.73e-02 weight
6.44e-05 5.60e+01 input[0]
4.96e-07 1.24e+01 output
                  base_model.model.model.layers.31.self_attn.k_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 9.58e+01 input[0]
0.00e+00 1.85e+01 output
                  base_model.model.model.layers.31.self_attn.v_proj.lora_dropout.default Dropout
0.00e+00 9.58e+01 input[0]
0.00e+00 1.01e+02 output
                  base_model.model.model.layers.31.self_attn.v_proj.lora_A.default Linear
1.05e-07 1.07e-01 weight
0.00e+00 1.01e+02 input[0]
1.04e-03 5.54e+01 output
                  base_model.model.model.layers.31.self_attn.v_proj.lora_B.default Linear
7.20e-07 3.79e-02 weight
1.04e-03 5.54e+01 input[0]
7.59e-07 6.53e+00 output
                  base_model.model.model.layers.31.self_attn.v_proj Linear4bit
0.00e+00 2.55e+02 weight
0.00e+00 9.58e+01 input[0]
0.00e+00 8.99e+00 output
                  base_model.model.model.layers.31.self_attn.rotary_emb MistralRotaryEmbedding
0.00e+00 8.99e+00 input[0]
5.15e-05 1.00e+00 output[0]
0.00e+00 1.00e+00 output[1]
                  base_model.model.model.layers.31.self_attn.o_proj Linear4bit
0.00e+00 2.55e+02 weight
     nan      nan input[0]
     nan      nan output
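
In the trace above every weight is finite, and the first NaN shows up at the input of `o_proj`, that is, in the output of the attention computation. One mechanism that can produce this with perfectly healthy weights is a softmax over a row in which every score is masked to negative infinity. This is an assumption offered for illustration, not a confirmed diagnosis of this run. A pure-Python sketch:

```python
import math


def softmax(scores):
    """Numerically stabilised softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]  # -inf - (-inf) yields nan
    total = sum(exps)
    return [e / total for e in exps]


NEG_INF = float("-inf")

partly_masked = [NEG_INF, 0.5, 1.0]  # at least one key is visible
fully_masked = [NEG_INF] * 4         # no key is visible to this query

print(softmax(partly_masked))  # masked key gets probability 0.0
print(softmax(fully_masked))   # every entry is nan
```

Because the NaN is created in the activations, scanning `model.named_parameters()` before the step cannot reveal it, which is consistent with the parameter check printing nothing.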

muximus3 commented 1 year ago

image My training loss is behaving strangely as it suddenly explodes at different positions during each training. I attempted to resolve this issue by following the instructions in mistral-7b-instruct and setting padding_side to "right", with pad_token being set as eos_token, but it didn't solve the problem. I use deepspeed stage3 and bfloat16.
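
One thing worth checking with `pad_token = eos_token` and right padding is how the labels are masked. If padding is ignored by comparing token ids against the pad id, the genuine end-of-sequence token is ignored too, so masking from the attention mask is safer. The sketch below is a pure-Python illustration; the helper name and the token ids are hypothetical, and `-100` is the default `ignore_index` of PyTorch's cross-entropy loss.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss
EOS_ID = 2           # hypothetical id, reused as the pad id below


def pad_right(token_ids, max_len, pad_id):
    """Right-pad one sequence and build labels from the attention mask."""
    n_pad = max_len - len(token_ids)
    input_ids = token_ids + [pad_id] * n_pad
    attention_mask = [1] * len(token_ids) + [0] * n_pad
    labels = [
        tok if keep else IGNORE_INDEX
        for tok, keep in zip(input_ids, attention_mask)
    ]
    return input_ids, attention_mask, labels


example = [5, 17, 42, EOS_ID]  # a sequence ending in a real EOS token
input_ids, attention_mask, labels = pad_right(example, max_len=6, pad_id=EOS_ID)

# Masking by token id instead would also hide the real EOS at index 3.
labels_by_id = [IGNORE_INDEX if tok == EOS_ID else tok for tok in input_ids]

print(labels)
print(labels_by_id)
```

With id-based masking the model never receives a training signal for emitting EOS, which can show up later as generations that do not stop.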

nps798 commented 1 year ago

@younesbelkada thank you. I set the torch dtype to bf16 (while keeping padding on the left).

QLoRA fine-tuning then ran successfully for 5 epochs without exploding loss or zero loss.

I will keep experimenting with other combinations of parameters.

teknium1 commented 1 year ago

I can confirm that at least 2 other people now have this issue with FSDP. However, I still see the loss go up after the per-epoch drops in my training runs with DeepSpeed as well, which leaves me concerned, but in a better state than before, when the loss curves were always U-shaped. image image

younesbelkada commented 11 months ago

Hi everyone. Thanks a lot for the deep investigation. Recently @pacman100 managed to successfully fine-tune Llama using FSDP (from what I have understood, the issue is quite agnostic to the architecture) and shared some insights here: https://github.com/huggingface/accelerate/issues/2127#issuecomment-1802641032 It seems the solution is to not load the model in bf16 and instead enable mixed-precision training through TrainingArguments by passing bf16=True. cc @pacman100 in case I missed something.

jph00 commented 11 months ago

Thanks a lot for the deep investigation, recently @pacman100 managed to successfully fine-tune llama (from what I have understood the issue is quite agnostic to the architecture) using FSDP and shared some insights here

I think this was a misunderstanding, and actually it's not successfully training. However @tmabraham did show a workaround in that thread.

pacman100 commented 11 months ago

Hello,

I ran the experiment below to check whether fine-tuning Mistral with FSDP works as expected. Here are the results:

  1. Codebase: https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/training

  2. Dataset: smangrul/chat-instruct-mixer

  3. Model: mistralai/Mistral-7B-v0.1

  4. Accelerate config after running accelerate config --config_file fsdp_config.yaml and answering the questionnaire:

    compute_environment: LOCAL_MACHINE
    debug: false
    distributed_type: FSDP
    downcast_bf16: 'no'
    fsdp_config:
      fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
      fsdp_backward_prefetch_policy: BACKWARD_PRE
      fsdp_cpu_ram_efficient_loading: true
      fsdp_forward_prefetch: false
      fsdp_offload_params: false
      fsdp_sharding_strategy: 1
      fsdp_state_dict_type: SHARDED_STATE_DICT
      fsdp_sync_module_states: true
      fsdp_use_orig_params: true
    machine_rank: 0
    main_training_function: main
    mixed_precision: bf16
    num_machines: 1
    num_processes: 8
    rdzv_backend: static
    same_network: true
    tpu_env: []
    tpu_use_cluster: false
    tpu_use_sudo: false
    use_cpu: false
  5. Command:

    accelerate launch \
    --config_file configs/fsdp_config.yaml \
    train.py \
    --model_name "mistralai/Mistral-7B-v0.1" \
    --dataset_name "smangrul/chat-instruct-mixer" \
    --max_seq_len 4096 \
    --max_steps 5000 \
    --logging_steps 25 \
    --eval_steps 1000 \
    --save_steps 1000 \
    --bf16 True \
    --packing True \
    --output_dir "/fsx/sourab/experiments/full-finetune-mistral-7b-fsdp-chat-asst" \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --dataset_text_field "content" \
    --use_gradient_checkpointing False \
    --learning_rate 5e-6  \
    --lr_scheduler_type "cosine" \
    --weight_decay 0.01 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --use_flash_attn True
  6. Training plots at the end of 1000 steps:

    Screenshot 2023-11-16 at 1 59 09 PM
  7. Observations:

    a. Loss is going down as expected and it is successfully training.
    b. Sensitivity to learning rate: When I used learning rates of 5e-5 or 2e-5, the training was not converging properly. 5e-6 worked best for my dataset. So, when fully fine-tuning, hyperparameter tuning is important.
    c. seq-length 4096 with batch size 8 (per GPU 1 and gradient accumulation steps 1) has lower loss when compared to seq-length 2048 with batch size 16 (per GPU 1 and gradient accumulation steps 2).

  8. Library versions:

    • Output of transformers-cli env:
    • transformers version: 4.35.2
    • Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
    • Python version: 3.11.4
    • Huggingface_hub version: 0.16.4
    • Safetensors version: 0.3.2
    • Accelerate version: 0.24.1
    • Accelerate config: not found
    • PyTorch version (GPU?): 2.1.0.dev20230809 (True)
    • Tensorflow version (GPU?): not installed (NA)
    • Flax version (CPU?/GPU?/TPU?): not installed (NA)
    • Jax version: not installed
    • JaxLib version: not installed
    • Using GPU in script?:
    • Using distributed or parallel set-up in script?:
    • Output of accelerate env:
    • Accelerate version: 0.24.1
    • Platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
    • Python version: 3.11.4
    • Numpy version: 1.24.3
    • PyTorch version (GPU?): 2.1.0.dev20230809 (True)
    • PyTorch XPU available: False
    • PyTorch NPU available: False
    • System RAM: 1121.82 GB
    • GPU type: NVIDIA A100-SXM4-80GB
    • Accelerate default config: Not found
    • flash-attn: 2.3.3
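
As a side note on observation (c), the two configurations being compared process the same number of tokens per optimizer step if every packed sequence is full. The short calculation below makes that explicit. The function name is hypothetical; the GPU count comes from `num_processes: 8` in the config above.

```python
def tokens_per_step(num_gpus, per_device_batch, grad_accum, seq_len):
    """Return (sequences, tokens) consumed by one optimizer step."""
    sequences = num_gpus * per_device_batch * grad_accum
    return sequences, sequences * seq_len


# Assumes packing fills every sequence up to seq_len.
long_run = tokens_per_step(num_gpus=8, per_device_batch=1, grad_accum=1, seq_len=4096)
short_run = tokens_per_step(num_gpus=8, per_device_batch=1, grad_accum=2, seq_len=2048)

print(long_run, short_run)
```

Under that assumption, the difference in loss between the two runs would come from sequence length and accumulation rather than from the token budget per step.
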

github-actions[bot] commented 10 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

teknium1 commented 10 months ago

Is this solved due to the previous mention?