zyushun / Adam-mini

Code for Adam-mini: Use Fewer Learning Rates To Gain More https://arxiv.org/abs/2406.16793

How to Use AdamMini Optimizer with Weight-Tying on Models like qwen0.5b? #11

Closed: relic-yuexi closed this issue 1 month ago

relic-yuexi commented 1 month ago

Thanks for your hard work. I would like to know how to use the Adam-mini optimizer with weight-tying models such as qwen0.5b. Are any specific configurations or adjustments needed? Any guidance or example code would be appreciated.

zyushun commented 1 month ago

Hi @relic-yuexi ! Thanks for your interest in Adam-mini!

Adam-mini supports weight-tying, so you can run it without any adjustments. You may see a warning from Adam-mini such as "No output layer found"; this warning can be safely ignored when weight-tying is used.
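The thread does not say why this warning appears with weight-tying. One plausible reading, assuming Adam-mini locates the embedding and output layers by scanning parameter names from `model.named_parameters()`: in a weight-tied model, `lm_head.weight` is the very same `Parameter` object as `embed_tokens.weight`, and PyTorch's `named_parameters()` yields a shared parameter only once, under the name it was first registered with. The toy model below (`TinyTiedLM` is hypothetical, not Qwen code) shows that effect.

```python
import torch.nn as nn

class TinyTiedLM(nn.Module):
    """Hypothetical toy model: output projection tied to the input embedding."""
    def __init__(self, vocab: int = 10, dim: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)          # weight: (vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)      # weight: (vocab, dim)
        # Weight-tying: both modules now share one Parameter object.
        self.lm_head.weight = self.embed_tokens.weight

model = TinyTiedLM()
# named_parameters() skips duplicates, so the tied lm_head.weight is never listed
names = [name for name, _ in model.named_parameters()]
print(names)  # ['embed_tokens.weight']
```

Under that assumption, any name-based search for an output layer would come back empty, which matches the "0 output layers" message reported later in this thread.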

relic-yuexi commented 1 month ago

Thanks for your answer.

I used the following code:

from trl import SFTTrainer  # assumption: SFTTrainer from TRL
# AdamMini is the Adam-mini optimizer class; its import was not shown in the report.

class AdamMiniTrainer(SFTTrainer):
    def __init__(self, finetuning_args, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.finetuning_args = finetuning_args

    def create_optimizer(self):
        if self.optimizer is None:
            if self.finetuning_args.use_adammini:
                # Read the attention shape from the Hugging Face model config
                n_embd = self.model.config.hidden_size
                n_head = self.model.config.num_attention_heads
                # Grouped-query attention: fall back to n_head if there is no KV-head count
                n_query_groups = getattr(self.model.config, 'num_key_value_heads', n_head)

                self.optimizer = AdamMini(
                    model=self.model,
                    lr=self.args.learning_rate,
                    weight_decay=self.args.weight_decay,
                    betas=(self.args.adam_beta1, self.args.adam_beta2),
                    eps=self.args.adam_epsilon,
                    model_sharding=True,
                    n_feature=n_embd,
                    n_head=n_head,
                    n_kv_head=n_query_groups,
                )

        # Trainer.create_optimizer() keeps self.optimizer if it is already set
        return super().create_optimizer()

Then I got this error:

    state["m"] = state["m"].view(-1, dim)
RuntimeError: shape '[-1, 401408]' is invalid for input of size 57344

The model architecture is:

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2FlashAttention2(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Linear(in_features=896, out_features=151936, bias=False)
)
model.config.hidden_size: 896
model.config.num_attention_heads: 14
model.config.num_key_value_heads: 2

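The thread never pins down which tensor triggered the error, but both numbers in it can be derived from the config printed above. The arithmetic below is one possible reading, not a confirmed diagnosis: 57344 is exactly one key/value head's slice of `k_proj.weight` (64 × 896), and also half of the whole `k_proj.weight`, which is what each of two GPUs would hold if ZeRO-3 split it evenly (an assumption); 401408 is half of the 896 × 896 `q_proj.weight`, equivalently 896² / n_kv_head.

```python
# Values from the Qwen2ForCausalLM printout and config above.
n_feature = 896                              # hidden_size
n_head = 14                                  # num_attention_heads
n_kv_head = 2                                # num_key_value_heads
head_dim = n_feature // n_head               # 64

q_numel = n_feature * n_feature              # q_proj.weight: 896 x 896
k_numel = n_kv_head * head_dim * n_feature   # k_proj.weight: 128 x 896

one_kv_head_block = head_dim * n_feature     # rows belonging to one K/V head
half_of_q = q_numel // 2                     # also equals n_feature**2 // n_kv_head

print(head_dim, q_numel, k_numel)            # 64 802816 114688
print(one_kv_head_block, k_numel // 2)       # 57344 57344
print(half_of_q)                             # 401408
```

In other words, the view in the traceback asked for rows sized like a slice of the square `q_proj` while receiving a tensor sized like a slice of the much smaller grouped-query `k_proj`.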
zyushun commented 1 month ago

Hi @relic-yuexi !

Perhaps you gave Adam-mini the wrong configuration for your model. You need to tell Adam-mini the n_feature, n_head, and n_kv_head of your Qwen model. Here is a refined example (sorry for any confusion caused by our previous example in README.md):

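from Adam_mini import Adam_mini  # assumes the optimizer file is importable as module Adam_mini

# model, lr, beta1, beta2, eps, weight_decay, dim, n_head, n_kv_head come from your training script
optimizer = Adam_mini(
        model = model,
        lr = lr,
        betas = (beta1, beta2),
        eps = eps,
        weight_decay = weight_decay,
        model_sharding = True,
        n_feature = dim,
        n_head = n_head,
        n_kv_head = n_kv_head
    )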

model_sharding: set to True if you are using model parallelism with more than 1 GPU, including FSDP and ZeRO-1/2/3 in DeepSpeed. Set to False if you are using DDP or single-GPU training.

n_feature: dimension of the hidden features (hidden_size in Hugging Face configs).

n_head: number of attention heads.

n_kv_head: number of heads for Key and Value, or equivalently, the number of query groups in Grouped-Query Attention (also known as "n_query_groups"). If None, it defaults to n_head.
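The four rules above can be collected into one small function. This is a sketch, not part of Adam-mini: `adam_mini_kwargs` and the strategy labels (`"single_gpu"`, `"ddp"`, `"fsdp"`, `"zero1"`, `"zero2"`, `"zero3"`) are hypothetical names, and the config attributes assume a Hugging Face-style config such as the Qwen2 one in this thread.

```python
from types import SimpleNamespace

# Hypothetical strategy labels; not an Adam-mini or Hugging Face API.
SHARDED = {"fsdp", "zero1", "zero2", "zero3"}
UNSHARDED = {"single_gpu", "ddp"}

def adam_mini_kwargs(config, strategy):
    """Hypothetical helper: derive Adam-mini's shape and sharding arguments
    from a Hugging Face-style config, following the rules listed above."""
    if strategy not in SHARDED | UNSHARDED:
        raise ValueError(f"unknown strategy: {strategy}")
    n_head = config.num_attention_heads
    # A missing or None KV-head count means "same as n_head" (no GQA).
    n_kv_head = getattr(config, "num_key_value_heads", None) or n_head
    if config.hidden_size % n_head or n_head % n_kv_head:
        raise ValueError("hidden_size must divide by n_head, and n_head by n_kv_head")
    return {
        "model_sharding": strategy in SHARDED,
        "n_feature": config.hidden_size,
        "n_head": n_head,
        "n_kv_head": n_kv_head,
    }

qwen_cfg = SimpleNamespace(hidden_size=896, num_attention_heads=14, num_key_value_heads=2)
print(adam_mini_kwargs(qwen_cfg, "zero3"))
# {'model_sharding': True, 'n_feature': 896, 'n_head': 14, 'n_kv_head': 2}
```

Assuming the constructor signature shown in this thread, the result could be unpacked alongside the remaining arguments, e.g. `Adam_mini(model=model, lr=lr, ..., **adam_mini_kwargs(model.config, "zero3"))`.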

zyushun commented 1 month ago

Hi @relic-yuexi !

Just noticed your configuration. In your case, these numbers should be passed to Adam-mini:

n_feature = 896
n_head = 14
n_kv_head = 2

Does the error still occur after passing this information to Adam-mini?

relic-yuexi commented 1 month ago

n_embd: 896, n_head: 14, n_query_groups: 2

Yes, I get the same error:

    state["m"] = state["m"].view(-1, dim)
RuntimeError: shape '[-1, 401408]' is invalid for input of size 57344

zyushun commented 1 month ago

Hi @relic-yuexi .

Please try again with the latest version of Adam-mini.py and call it as follows:

optimizer = Adam_mini(
        model = model, 
        lr = lr, 
        betas = (beta1,beta2), 
        eps = eps,
        weight_decay = weight_decay,
        model_sharding = True,
        n_feature = 896,
        n_head = 14,
        n_kv_head = 2
    )

I think it should work then.

relic-yuexi commented 1 month ago

Same error here. It's weird, because Phi3ForCausalLM works.

    state["m"] = state["m"].view(-1, dim)
RuntimeError: shape '[-1, 401408]' is invalid for input of size 57344
Adam-mini found 1 embedding layers, 0 output layers, 24 Querys, 24 Keys.
=====>>> Warning by Adam-mini: No output layer found.  If you are training Transformers (without weight-tying), please check the name of your output layer and manually add them to 'self.embd_blocks' of Adam-mini. Please ignore this warning if you are using weight-tying.
n_embd: 896, n_head: 14, n_kv_head: 2
Adam-mini found 1 embedding layers, 0 output layers, 24 Querys, 24 Keys.
=====>>> Warning by Adam-mini: No output layer found.  If you are training Transformers (without weight-tying), please check the name of your output layer and manually add them to 'self.embd_blocks' of Adam-mini. Please ignore this warning if you are using weight-tying.
n_embd: 896, n_head: 14, n_kv_head: 2

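The log shows Adam-mini singling out Query and Key weights ("24 Querys, 24 Keys"), and the paper's title refers to using fewer learning rates. A simplified illustration of that idea, assuming Query/Key weights are split into one block per attention head and each block shares a single second-moment value (the mean of its squared gradients), is sketched below. It is not Adam-mini's actual code, and `blockwise_second_moment` is a hypothetical name. It also reproduces the traceback's message when a tensor's size is not a multiple of the requested row length.

```python
import torch

n_feature, n_head, n_kv_head = 896, 14, 2   # Qwen2 0.5B values from the thread
head_dim = n_feature // n_head              # 64
block = head_dim * n_feature                # elements per head: 64 rows x 896 cols

def blockwise_second_moment(grad: torch.Tensor, block_numel: int) -> torch.Tensor:
    """Hypothetical helper: one scalar per block, the mean of squared gradients.
    Simplified illustration of per-block learning rates, not Adam-mini's code."""
    rows = grad.reshape(-1, block_numel)    # fails unless numel is divisible by block_numel
    return rows.pow(2).mean(dim=1)          # one value per block (here: per head)

q_grad = torch.randn(n_head * head_dim, n_feature)      # like q_proj.weight, 896 x 896
k_grad = torch.randn(n_kv_head * head_dim, n_feature)   # like k_proj.weight, 128 x 896

print(blockwise_second_moment(q_grad, block).shape)     # torch.Size([14])
print(blockwise_second_moment(k_grad, block).shape)     # torch.Size([2])

# A flat 57344-element tensor cannot be viewed with rows of 401408 elements:
try:
    torch.zeros(57344).view(-1, 401408)
except RuntimeError as err:
    print(err)  # shape '[-1, 401408]' is invalid for input of size 57344
```

Because the grouped-query `k_proj` has only 2 heads' worth of rows, any block size derived from the full 896 × 896 shape cannot divide it, which is one way the reported error could arise.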
zyushun commented 1 month ago

Hi @relic-yuexi ,

We have pushed another update to Adam-mini.py. Please try again with the new file.

relic-yuexi commented 1 month ago

I'm running on two RTX 4090 GPUs with DeepSpeed ZeRO-3. It seems the optimizer costs 3-4 GB more memory.

[Three screenshots not preserved]

zyushun commented 1 month ago

Hi @relic-yuexi ! It is good to know that you can set up the runs now.

It is unexpected that Adam-mini requires more memory. Some immediate questions related to your figures:

  1. Can you share what "pageadam" is?
  2. In the figure, why are there 4 GPUs in total?

Best, Yushun

relic-yuexi commented 1 month ago

The other GPUs are being used by other people 😂

[Screenshot not preserved]

You can find it here:

https://github.com/huggingface/transformers/blob/aec1ca3a588bc6c65f7886e3d3eaa74901a6356f/src/transformers/training_args.py#L162

And here is how they are used:

https://github.com/huggingface/transformers/blob/aec1ca3a588bc6c65f7886e3d3eaa74901a6356f/src/transformers/trainer.py#L1151

relic-yuexi commented 1 month ago

I tested fused Adam. Maybe you can reproduce the results and compare them on GPT-2 or other models.

[Two screenshots not preserved]

zyushun commented 1 month ago

Hi @relic-yuexi, thanks for the update! It is good to see that Adam-mini reduces memory compared with AdamW (torch fused) in your case.

We took a look at PageAdam and found that it automatically offloads to the CPU. This might be why you did not observe GPU memory savings over PageAdam. Note that CPU offload incurs higher latency.
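The thread does not break down where the memory goes. As a rough yardstick, the sketch below counts Qwen2-0.5B's parameters from the module printout earlier in this thread and sizes AdamW's two optimizer states. Assumptions: `lm_head` is tied to `embed_tokens` (so counted once), states are fp32, and ZeRO-3 splits them evenly over the two GPUs. It ignores weights, gradients, fp32 master copies, and activations; how much of `v` Adam-mini actually removes depends on which blocks it keeps per-coordinate, which the thread does not say.

```python
# Parameter count of Qwen2-0.5B, read off the module printout in this thread.
# Assumption: lm_head is tied to embed_tokens, so it adds no parameters.
hidden, vocab, kv_out, ffn, n_layers = 896, 151936, 128, 4864, 24

embed = vocab * hidden
attn = hidden * hidden + hidden              # q_proj (with bias)
attn += 2 * (hidden * kv_out + kv_out)       # k_proj, v_proj (with bias)
attn += hidden * hidden                      # o_proj (no bias)
mlp = 3 * hidden * ffn                       # gate/up/down (no bias)
norms = 2 * hidden                           # two RMSNorm weights per layer
total = embed + n_layers * (attn + mlp + norms) + hidden  # + final norm

# AdamW keeps two fp32 tensors (m and v) with one entry per parameter.
bytes_per_state = 4
adamw_state_bytes = 2 * bytes_per_state * total
v_bytes = bytes_per_state * total            # upper bound on what shrinking v can save

GiB = 1024 ** 3
print(total)                                  # 494032768
print(round(adamw_state_bytes / GiB, 2))      # 3.68
print(round(v_bytes / GiB, 2))                # 1.84
print(round(adamw_state_bytes / 2 / GiB, 2))  # 1.84 per GPU if split over 2 GPUs
```

Under these assumptions, the whole AdamW state is about 1.84 GiB per GPU, which gives a scale against which the 3-4 GB difference reported above can be compared.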

[Screenshot not preserved]

relic-yuexi commented 1 month ago

Maybe not. PageAdam seems faster than Adam-mini. You can reproduce the result with qwen-0.5b and seq_len=4096, logged with wandb.

[Screenshot not preserved]

zyushun commented 1 month ago

Hi @relic-yuexi , sorry for the late response.

One possible reason is that PageAdam is highly optimized with fused CUDA kernels, which speed up wall-clock time. In contrast, Adam-mini currently has only a plain implementation. We will try to combine Adam-mini with these speed-up techniques in the near future.