zyushun / Adam-mini

Code for Adam-mini: Use Fewer Learning Rates To Gain More https://arxiv.org/abs/2406.16793

Qwen2-0.5B cannot be Adam-mini-optimized in 4 shards (Deepspeed Zero-3) #25

Open xiningnlp opened 2 months ago

xiningnlp commented 2 months ago

Hi all,

I found that Adam-mini 1.0.1 cannot run in 4 shards; it throws an exception related to tensor reshaping:

  File "/opt/conda/lib/python3.10/site-packages/adam_mini/adam_mini.py", line 175, in step
  File "/data/suanfa_lvm_data/xining/git/LLaMA-Factory/src/train.py", line 19, in main
(---> line 18)    m = torch.zeros_like(p, dtype=torch.float32)
(---> line 19)    state["m"] = m.view(-1, head_numel)
RuntimeError: shape '[-1, 57344]' is invalid for input of size 200704

After debugging the issue, I found that Adam-mini fails to compute the "m" state for model.layers.0.self_attn.q_proj.weight.

To understand the above exception, here is the config of Qwen2-0.5B:

{
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 896,
  "initializer_range": 0.02,
  "intermediate_size": 4864,
  "max_position_embeddings": 32768,
  "max_window_layers": 21,
  "model_type": "qwen2",
  "num_attention_heads": 14,
  "num_hidden_layers": 24,
  "num_key_value_heads": 2,
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sliding_window": 32768,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.1",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
  1. head_numel = hidden_size * hidden_size // num_attention_heads = 57344, which is expected.
  2. p is a 1-D tensor of length 200704 (= hidden_size * hidden_size / shard_num), which is also expected under DeepSpeed ZeRO-3 with 4 cards.
  3. However, 200704 is not divisible by 57344 (200704 / 57344 = 3.5), so the reshape in line 19 fails (a standalone reproduction follows below).

Does this mean that whether Adam-mini can be used in a DeepSpeed environment depends on the shard number and the hidden size?
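
For reference, the mismatch can be reproduced in isolation with just the numbers above (this is a standalone sketch of the failing reshape, not the actual Adam-mini / DeepSpeed code path):

import torch

# Qwen2-0.5B attention settings (from the config pasted above)
hidden_size = 896
num_attention_heads = 14
num_shards = 4  # DeepSpeed ZeRO-3 across 4 GPUs

# Adam-mini groups the second-moment state per attention head
head_numel = hidden_size * hidden_size // num_attention_heads  # 57344

# Under ZeRO-3, each rank only holds a flat 1-D shard of q_proj.weight
shard_numel = hidden_size * hidden_size // num_shards  # 200704

p = torch.zeros(shard_numel)  # stand-in for the local shard of q_proj.weight
m = torch.zeros_like(p, dtype=torch.float32)
m.view(-1, head_numel)  # RuntimeError: shape '[-1, 57344]' is invalid for input of size 200704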

zyushun commented 2 months ago

Hi @xiningnlp . Thanks for the question!

The reason for the error is that the current implementation of Adam-mini only supports the case where num_attention_heads / num_gpu is an integer. In your case with Qwen2-0.5B, num_attention_heads / num_gpu = 14 / 4 = 3.5 is not an integer, which causes the error.

Thanks for mentioning this. We will try to support more flexible choices of num_attention_heads in the future. For now, you can try the following simple tweaks:

  1. Try num_gpu = 2 or 7. In these cases, num_attention_heads / num_gpu is an integer.
  2. Still use num_gpu = 4 but train another model scale, such as 1.5B or 7B. All the other Qwen models have a num_attention_heads that is a multiple of 4, so they will not raise this error; Qwen2-0.5B is actually the only exception whose num_attention_heads is not a multiple of 4.
  3. If you do not want to change num_gpu or the architecture, you can add the following line after creating the optimizer (a usage sketch follows below):
optimizer.wqk_names = {}

This forces Adam-mini to treat Q and K in the same way as regular MLP layers, so no head-related partition operations are involved. For SFT, this change usually does not cause performance degradation, but please tell us if you observe any.
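
For concreteness, here is a minimal sketch of tweak 3. The Adam_mini constructor arguments below follow the usage example in this repo's README; please double-check them against your installed adam_mini version, and replace the hyperparameters with your own:

from transformers import AutoModelForCausalLM
from adam_mini import Adam_mini

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")

optimizer = Adam_mini(
    named_parameters=model.named_parameters(),
    lr=1e-5,                 # example value, use your own
    betas=(0.9, 0.999),
    weight_decay=0.1,
    dim=896,                 # hidden_size of Qwen2-0.5B
    n_heads=14,              # num_attention_heads
    n_kv_heads=2,            # num_key_value_heads
)

# Tweak 3: treat Q/K like regular MLP layers, skipping head-wise partitioning
optimizer.wqk_names = {}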

xiningnlp commented 2 months ago

@zyushun Thanks for your prompt reply. On the other hand, theoretically there is little need to use Adam-mini for a 0.5B SLM, since full-parameter SFT of a 0.5B model does not consume much GPU memory, right? And in practice, I saw no memory gain in my experiment. Is this observation expected?

zyushun commented 2 months ago

Hi @xiningnlp. Yes, you are right. For a 0.5B model, optimizer memory is not a heavy overhead.

Adam-mini does not save much memory over Adam for 0.5B models. This is because Adam-mini still uses AdamW for the embedding layer, and the embedding layer takes a large proportion of the total parameters in a 0.5B model. So it is expected that you did not observe much memory reduction from Adam-mini on a 0.5B model.

Nevertheless, things are different once the model size grows beyond 1B. In those cases, the proportion of the embedding layer shrinks to <10%, and the memory savings of Adam-mini become significant (you will see a ~50% cut-down over Adam).
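
As a rough back-of-the-envelope check using the config above (treating the total parameter count of Qwen2-0.5B as roughly 0.5B, which is an approximation):

# Share of Qwen2-0.5B parameters sitting in the (tied) embedding layer
vocab_size = 151936
hidden_size = 896
total_params = 0.5e9  # approximate total parameter count

embedding_params = vocab_size * hidden_size  # ~136M; tied with the LM head, so counted once
print(embedding_params / total_params)       # ~0.27 -> roughly a quarter of all parameters,
                                             # which still carries full AdamW states under Adam-mini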