zyushun / Adam-mini

Code for Adam-mini: Use Fewer Learning Rates To Gain More https://arxiv.org/abs/2406.16793

works well on smaller models, updates for torchtitan and 8B size #7

Open lessw2020 opened 1 month ago

lessw2020 commented 1 month ago

Hi there! First, really nice job on developing adam-mini! It's a really refreshing approach for Adam. I did some initial testing by integrating adam_mini into torchtitan, and ran it with hand-made llama3 models of varying sizes, from 10M to 8B (official size).

What I found is that adam-mini is very competitive with, if not better than, AdamW at small model sizes, but declines as the model gets larger.
At 8B, the loss spikes all over the place during the lr warmup phase, and only settles into a more monotonic decline once the lr starts its linear decay... but by that point it is so far behind that it cannot compete.

I'm wondering if there is any thought to further sub-dividing the Hessian sub-groups as model size grows, or do you have any other input on helping adam-mini be effective at, say, 8B scale?

I did see a large memory saving at 8B (~9.5% total reduction per GPU), so that is promising, but as noted, training accuracy is not competitive at that scale.
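For context on where such a saving comes from: AdamW keeps two per-parameter states (m and v), while Adam-mini keeps m per parameter but only one v value per parameter block, except for the blocks it deliberately leaves on AdamW (the embedding/output layers). The sketch below is a rough estimate of optimizer-state size only, assuming fp32 states. It ignores weights, gradients, activations and sharding, so it is not expected to reproduce the ~9.5% whole-GPU figure above, and the parameter and block counts in the example are hypothetical.

```python
def optimizer_state_bytes(n_params: int, n_blocks: int,
                          n_adamw_params: int = 0, bytes_per_value: int = 4) -> dict:
    """Rough optimizer-state size (bytes) for AdamW vs. Adam-mini.

    Assumptions (illustrative, not taken from Adam_mini.py):
      * every state value is stored in fp32 (4 bytes) by default;
      * AdamW stores m and v for every parameter;
      * Adam-mini stores m for every parameter, a full per-parameter v only for
        the n_adamw_params parameters it routes to AdamW, and one v per block
        for the remaining n_blocks blocks.
    """
    adamw = 2 * n_params * bytes_per_value
    adam_mini = (n_params + n_adamw_params + n_blocks) * bytes_per_value
    return {"adamw": adamw, "adam_mini": adam_mini,
            "saved_fraction": 1 - adam_mini / adamw}


# Hypothetical, roughly 8B-sized counts: 8e9 params, ~1e9 of them on the AdamW
# path (embeddings/output), and 10,000 Adam-mini blocks.
est = optimizer_state_bytes(8_000_000_000, 10_000, n_adamw_params=1_000_000_000)
for key in ("adamw", "adam_mini"):
    print(f"{key:9s} {est[key] / 2**30:6.1f} GiB")
print(f"optimizer state saved: {est['saved_fraction']:.1%}")
```

The per-block v term is tiny, so the saving is dominated by how many parameters stay on the AdamW path.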

If you are interested in running experiments, I have posted the code (I had to make a couple of integration tweaks) here: https://github.com/lessw2020/torchtitan_oss/tree/adam_mini

zyushun commented 1 month ago

Thanks for your interest and thanks for the feedback!

I noticed that you skipped the "all_reduce" operation in our original implementation (line 222 of Adam_mini.py). Please try putting "all_reduce" back and see if it helps. Another simple tweak is to try a smaller lr and see if that is better.
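Why the all_reduce matters once the model is sharded: Adam-mini keeps one second-moment value per block, driven by the mean of the squared gradient over the whole block. Under sharding (e.g. FSDP), each rank holds only part of a block, so a purely local mean is not the block mean. The pure-Python simulation below stands in for the distributed case: the "ranks" are plain lists, and summing the per-rank partial sums plays the role of an all_reduce with SUM. It illustrates the idea and is not the code in Adam_mini.py.

```python
def local_block_means(shards):
    """What each rank would use if it skipped the all_reduce: mean of g^2 over its own shard only."""
    return [sum(g * g for g in shard) / len(shard) for shard in shards]


def global_block_mean(shards):
    """Mean of g^2 over the whole block, rebuilt from per-rank partial sums.

    In a real run each rank computes its partial sum locally, and an
    all_reduce(SUM) makes the total (and hence the mean) identical on every rank.
    """
    partial_sums = [sum(g * g for g in shard) for shard in shards]  # one per rank
    partial_counts = [len(shard) for shard in shards]
    return sum(partial_sums) / sum(partial_counts)


# One block of 6 gradient values, sharded unevenly across 2 hypothetical ranks.
shards = [[1.0, 1.0, 1.0, 1.0], [3.0, 3.0]]
print(local_block_means(shards))  # the two ranks disagree: [1.0, 9.0]
print(global_block_mean(shards))  # 22 / 6, shared by both ranks
```

Averaging the two local means (5.0) is also wrong when shards are uneven, which is why the sketch reduces sums and counts rather than means.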

We are currently reproducing your results. If you have any further findings, please feel free to update them here.

zyushun commented 1 month ago

Hi, I just found that our implementation of "all_reduce" is not compatible with your code. We are now working on a new "all_reduce" implementation.

zyushun commented 1 month ago

Hi, further update:

I just found that you are actually using (1) the default PyTorch partition, (2) empty embd_blocks, and (3) no all_reduce. All of these cause training instability and should be changed, as we discussed in the paper.

Why do (1) and (2) happen? This is actually due to our naive implementation (sorry for the trouble). Specifically, we currently decide whether the current block is an embedding, Q, or K block by its name (e.g., "attn.wq.weight"). The names in your code are different from ours (e.g., you use "attention.wq.weight" instead of "attn.wq.weight"), so all our if-else conditions fail.
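To illustrate the failure mode: because the partition is chosen by substring checks on parameter names, a key written for one naming scheme silently never matches another, and every parameter falls through to the default rule. A minimal sketch; the keyword tuples are hypothetical stand-ins in the spirit of the name-based checks, not the exact lists in Adam_mini.py.

```python
# Hypothetical keyword tuples; the real lists live in Adam_mini.py and may differ.
EMBED_KEYS = ("embed_tokens", "wte", "lm_head")
QK_KEYS = ("attn.wq", "attn.wk")  # written for "attn.*"-style parameter names


def classify(name: str, embed_keys=EMBED_KEYS, qk_keys=QK_KEYS) -> str:
    """Return the partition rule a parameter name falls under (substring match)."""
    if any(key in name for key in embed_keys):
        return "embedding"
    if any(key in name for key in qk_keys):
        return "query/key"
    return "default"


# torchtitan-style names, as described above
names = ["tok_embeddings.weight", "layers.0.attention.wq.weight", "layers.0.attention.wk.weight"]

# With the original keys every name falls through to "default": no error, just wrong blocks.
print([classify(n) for n in names])

# Adding keys for the torchtitan naming scheme restores the intended routing.
print([classify(n, EMBED_KEYS + ("tok_embeddings",), QK_KEYS + ("attention.wq", "attention.wk"))
       for n in names])
```

Because nothing raises when no key matches, a check like this is easy to get silently wrong when porting to a codebase with different module names.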

Sorry for the inconvenience. We will push an update together with the compatible all_reduce operator.

lessw2020 commented 1 month ago

Hi @zyushun! Thanks very much for investigating this and very glad to hear that you've isolated things re: titan and llama3. Will be happy to run things if you have an update both for 8B and then 70B.

Also, I should add that the reason for the changes to the all_reduce was that titan uses DTensor, so tmp_lr is a DTensor and thus cannot be used in a normal all_reduce. Instead, I had to force an all_reduce via the line I added. I also tested a workaround, which I think I left in the comments: tmp_lr.to_local(), then a regular all_reduce, then back to DTensor. But I found the precision was better when simply doing the direct all_reduce via DTensor. I can also check how to keep it from becoming a DTensor in the first place, though I think it works fine as a DTensor at the moment.

Lastly, the main repo for Titan is here (in case you want to make your own fork). Please make sure you run with the latest PyTorch nightly, since the APIs are changing a fair amount and older PyTorch builds may break: https://github.com/pytorch/torchtitan

Please let me know what you find and also if you have further questions about titan/llama3!

zyushun commented 1 month ago

Hi @lessw2020!

I just pushed a new version of Adam-mini that works with your codebase. Please use this latest version as follows:

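```python
from Adam_mini import Adam_mini

optimizer = Adam_mini(
    model,
    lr=lr,
    betas=(0.9, 0.95),
    weight_decay=0.1,
    model_sharding=True,  # parameters are sharded across GPUs
    n_embd=4096,          # hidden size of Llama3-8B
    n_head=32,            # number of attention heads
    n_query_groups=4,
)
```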

Using your codebase, I pre-trained Llama3-8B for 1000 steps and got the following curves. Adam-mini performs almost the same as AdamW. (Remark: I tuned the lr of Adam-mini down to 1e-4, while the lr of AdamW remains 3e-4.)

[Figure: loss curves, Llama3-8B, 1000 steps, Adam-mini (lr 1e-4) vs. AdamW (lr 3e-4)]

I used the same setting as you shared, except for the following (minor) changes:

  1. I used seq_len = 2048, since the original seq_len of 8192 caused an OOM error on my (poor) machine. You can change it back on your machine.

  2. I used "c4_mini", as you attached in the code.

  3. I changed "fused_rmsnorm" to "rmsnorm", because "fused_rmsnorm" is not compatible with our all_reduce operation for vmean. If you insist on using "fused_rmsnorm", a simple tweak is to use AdamW for the normalization layers as well. This adds only negligible memory, and I guess it won't affect performance. You can find these layers by the keyword "norm" in their parameter names and put them in the if-else condition for embedding blocks.

Please do not hesitate to contact us with any further updates! If it also works on your side, feel free to try it on larger models such as 70B! Looking forward to hearing from you soon.

Yushun

lessw2020 commented 1 month ago

Hi @zyushun, this is great news! Let me get things set up, and I will run and update here.

lessw2020 commented 1 month ago

Hi @zyushun, here are some initial results. It's looking much better with your changes; I think the only remaining delta may be slight lr adjustments. Changes:

  1. I used a sequence length of 8192 (vs. your 2048, which you reduced due to OOM).

  2. I used the full c4 dataset. c4_mini will loop and reuse data for a 1K-step run on 8 GPUs, so we don't recommend c4_mini except for short runs. C4 is huge, so there is no data looping.

  3. I compared both at 1e-4 to start. (Not shown yet: I also ran AdamW at 3e-4 and it performed even better than AdamW at 1e-4, so I now have Adam-mini running at 3e-4 as well. I will update soon, but Adam-mini at 3e-4 currently seems to track AdamW at 1e-4 almost exactly.)

  4. As you noted, I ran with rmsnorm to keep things matching.

Anyway, here are the results with 1e-4 for both:

[Screenshot 2024-07-06 at 2:48:10 PM: loss curves, Adam-mini vs. AdamW, both at lr 1e-4]

lessw2020 commented 1 month ago

Hi @zyushun, congrats! With a slight bump in lr (3e-4 for mini vs. 1e-4 for AdamW), mini shows very similar curves but with overall outperformance! IMO this is a very big accomplishment, as most optimizers can't meet or exceed AdamW at 8B scale, and especially not while reducing memory so significantly.
Results:

[Screenshot 2024-07-06 at 3:48:31 PM: loss curves, Adam-mini (3e-4) vs. AdamW (1e-4)]

And it does this with a significant drop in overall GPU memory requirements:

[Screenshot 2024-07-06 at 3:49:19 PM: GPU memory usage, Adam-mini vs. AdamW]

I'm currently running mini with 4e-4 to see if that then maps to adamw 3e-4. (5e-4 was unstable).

Regardless, this is extremely promising here! A couple general questions:

a - is there any option to adjust mini to integrate with fused_rmsnorm? The fused version speeds up training by about 15% and also reduces GPU memory, so it's an important part.

b - on Monday I'll set up runs on 70B so we can test it at even larger scale. Are there any adjustments you think I would need for 70B? (I'm not clear on how n_embd is computed, and thus not sure how to update it for 70B.)

lessw2020 commented 1 month ago

One more update: it seems 3e-4 is the optimal value for AdamW.
I tried to match it with mini at 4e-4, as well as at 5e-4 with a longer warmup (with the same warmup of 300, 5e-4 is unstable), but in both cases mini ends up ~0.15 behind:

loss curve:

[Screenshot 2024-07-06 at 6:13:20 PM: loss curves, AdamW (3e-4) vs. Adam-mini runs]

Is there anything else that could be adjusted here? It's very close, but I'm not sure what else might help.

zyushun commented 1 month ago

Hi @lessw2020 ! Thanks so much for the detailed update! We are sincerely grateful for your support and all the great discussion!

I have carefully read all your updates. For 8B model pre-training, it seems that:

  1. When seq_len = 2048: AdamW (3e-4, best tuned) = Adam-mini (1e-4), with highly similar loss curves.

  2. When seq_len = 8192: AdamW (3e-4, best tuned) is a bit faster than Adam-mini (4e-4, best tuned). The gap is about 0.15 in train loss, and the curves still look highly similar in shape.

This is quite interesting. I suspect something mysterious occurs in the long-context scenario. I will try to get 8 GPUs today to reproduce and shrink this 0.15 gap on my side, and I will update you as soon as I have new findings.

Remark: there is another possibility: this 0.15 gap may shrink as training goes on. This sometimes happens on small-scale tasks, too.

Quick answers to some of your other questions:

  1. n_embd: this is the hidden feature dimension of your LLM, also known as "hidden_size" in the Llama3 series. Note that this is not the vocabulary size (sorry for the potential confusion here). For Llama3-8B, n_embd = 4096; for Llama3-70B, n_embd = 8192.

  2. fused_rmsnorm: We are working on it, but it does not seem trivially solvable. For now, I suggest putting all the RMSNorm layers into the "embd_blocks" of Adam-mini. Blocks in "embd_blocks" are updated with AdamW, so no all_reduce operation is involved. This takes a one-line change: replace line 134 in Adam_mini.py with the following:

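```python
# Route embedding/output layers and (new) all norm layers to plain AdamW updates
if ("embed_tokens" in name or "wte" in name or "lm_head" in name
        or "tok_embeddings" in name or "output.weight" in name or "norm" in name):
    if p.grad is None:
        ...  # rest of the existing code for this branch, unchanged
```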

This change will only incur a negligible memory increase (so slight that I believe you will not even notice it). I haven't tried this before, but I guess it won't hurt performance.
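As a worked example of what n_embd, n_head and the KV-head count feed into: assuming, as the Adam-mini paper describes, that the query and key weights are partitioned into one block per attention head, the block layout follows directly from these three numbers. The helper below is hypothetical (not part of Adam_mini.py); the head counts are the published Llama3 configurations (8B: 32 query / 8 KV heads; 70B: 64 query / 8 KV heads).

```python
def qk_head_blocks(n_embd: int, n_head: int, n_kv_head: int) -> dict:
    """Block layout when query/key weights are split into one block per head.

    wq has shape (n_head * head_dim, n_embd); with grouped-query attention,
    wk has shape (n_kv_head * head_dim, n_embd). Each head is one block.
    """
    if n_embd % n_head != 0:
        raise ValueError("n_embd must be divisible by n_head")
    head_dim = n_embd // n_head
    return {
        "head_dim": head_dim,
        "wq_shape": (n_head * head_dim, n_embd),
        "wk_shape": (n_kv_head * head_dim, n_embd),
        "q_blocks": n_head,
        "k_blocks": n_kv_head,
        "block_numel": head_dim * n_embd,  # entries sharing one v value
    }


print(qk_head_blocks(n_embd=4096, n_head=32, n_kv_head=8))  # Llama3-8B
print(qk_head_blocks(n_embd=8192, n_head=64, n_kv_head=8))  # Llama3-70B
```

Going from 8B to 70B, head_dim stays at 128, so only n_embd and n_head change in the constructor arguments.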

I will try to get 8 GPUs today and reproduce your results on my side. Thanks again for your interest and support. Let us keep in touch.

lessw2020 commented 1 month ago

Hi @zyushun, Thanks for the info above!
I have another update that may help.

I switched to short context (2048) and reduced the run to 900 steps (with warmup = 300) so I could run more tests. I did an lr sweep from 1e-4 to 5e-4 for mini, with AdamW at 3e-4 as the control.
My results differ from yours in that the slightly higher lr still performed best among the lrs I tried for mini. 5e-4 became unstable. (Edit: note that for a few updates 5e-4 was doing slightly better than AdamW at 3e-4, so if it had remained stable it would likely have slightly outperformed, as mini at 3e-4 vs. AdamW at 1e-4 did at seq_len 8192.)

Here's the overview:

[Screenshot 2024-07-06 at 9:04:45 PM: lr sweep, Adam-mini 1e-4 to 5e-4 vs. AdamW 3e-4, best (lowest) loss points highlighted]

In this case mini with 4e-4 was the closest - I have highlighted the lowest point (best result) so you can see the absolute differences.

2 - Thanks for clarifying re: n_embd... I thought it referred to the vocab, since those are the 'embeddings'. It is really model_dim, and I think it would be helpful to rename the argument to something like dim_model to avoid confusion.

Anyway, in my tests I see a similar gap for both short and long context, with similar behavior: an lr slightly higher than AdamW's produces the results closest to AdamW (with, of course, large memory savings).

I'll kick off a longer run now with 2048 for 2400 iters, so we can see if the small gap gets closed and please let me know if you have any additional updates!

zyushun commented 1 month ago

Hi @lessw2020 ! Thanks for the swift response!

When seq_len = 2048, my previous result (AdamW 3e-4 = Adam-mini 1e-4) seems a bit different from yours (AdamW 3e-4, yellow, is slightly better than Adam-mini 4e-4, gray, but they are close and the curves already partially overlap). Why are our results (slightly) inconsistent? Is it due to "c4_mini" vs. "c4"? Or a different warm_up? I used warm_up = 40 (why 40? To be honest, I forgot why I chose this number...). I will try to figure this out and reproduce your result today.

Comparing the curves for seq_len = 2048 and seq_len = 8192, Adam-mini seems closer to AdamW at seq_len = 2048 than at seq_len = 8192. Huh, interesting long context!

Regarding "n_embd": we adopted this name from TinyLlama. Thanks for the suggestion; we will consider changing it.

I will keep updating.

lessw2020 commented 1 month ago

Hi @zyushun, a bit of progress. I ran 2400 iters (800 warmup), and over that run the gap between the best points on the two curves closed to 0.1, vs. closer to 0.2 in the early iters. So I'm kicking off a 5000-iter run now and also increasing the batch size to 8, since there is room with seq_len 2048.

Here's the 2400 chart:

[Screenshot 2024-07-06 at 11:01:34 PM: loss curves for the 2400-iter run]

Re: your results, I would say the issue there is that with c4_mini the data loops, so training is less challenging because the data repeats. Using c4 is therefore a better run from a data-integrity standpoint. Also re: warmup, I'm using much longer warmups (often 25% of total iters), so maybe you spent longer on the downhill part of the lr schedule and your results converged better.
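To make the warmup comparison concrete, here is a minimal sketch of a linear-warmup, linear-decay lr schedule, the warmup-then-linear-decline shape described in this thread. Treat it as an assumption: torchtitan's exact formula (e.g. off-by-one handling, or whether decay ends exactly at 0) may differ.

```python
def linear_warmup_decay(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """lr at `step`: linear warmup from 0 to peak_lr, then linear decay to 0 at total_steps.

    Assumed shape for illustration only; not torchtitan's actual scheduler code.
    """
    if not 0 < warmup_steps < total_steps:
        raise ValueError("need 0 < warmup_steps < total_steps")
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(total_steps - step, 0)  # clamp so steps past the end give lr = 0
    return peak_lr * remaining / (total_steps - warmup_steps)


# The 900-step, 300-warmup sweep setup from this thread, at peak lr 4e-4.
for step in (0, 150, 300, 600, 900):
    print(step, linear_warmup_decay(step, peak_lr=4e-4, warmup_steps=300, total_steps=900))
```

With a 40-step warmup, almost the whole run is spent on the decaying side of this curve, whereas a 25% warmup keeps the lr rising for a quarter of the run, so runs with different warmup lengths are not directly comparable at the same step.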

Hopefully my 5K run tonight and tomorrow morning will show whether the gap can be closed. Also, thanks for the tip on handling fused_rmsnorm. I will add that in tomorrow, but am sticking with rmsnorm for now so we have consistent comparisons across all of today's tests.

zyushun commented 1 month ago

Hi @lessw2020 , update on seq_len = 2048

I ran 10,000 steps with 2500 warm_up steps, as you recommended. I used lr = 3e-4 for both Adam-mini and AdamW, kept per-GPU bs = 2, and changed "c4_mini" to "c4" to avoid data looping. Here is what I got.

[Figure: loss curves, Adam-mini vs. AdamW at lr 3e-4, 10,000 steps, seq_len = 2048]

Adam-mini is slightly slower in the first 2000 steps but then catches up and slightly outperforms AdamW for the remaining steps. As a bonus finding, AdamW suffers a loss spike around step 4000+, while Adam-mini does not. But we do not intend to claim anything about it.

Hope you also got the similar results on your side :D

lessw2020 commented 1 month ago

Hi @zyushun, nice, thanks for the 10K data! I got similar results at 5K: the early lead by AdamW gets closed, and the best minimum loss differs by only 0.04. I ran with a larger batch size (8):

[Screenshot 2024-07-07 at 1:43:07 PM: 5K-step results, batch size 8]

And loss curves:

[Figure (mini_5K_results): loss curves for the 5K-step run]

I think I will update mini to use fused_rmsnorm (thanks for the info on how to do this!) and then may do a run with 10K as well and then try mini on 70B tomorrow.

Regardless, it's clear mini is a strong candidate here. I want to see if we can get it added as an optimizer option for titan, and, depending on time, I would be interested in moving it into a CUDA kernel to speed it up further.
Let me update it to use fused_rmsnorm, though, and kick off the 10K run as next steps.

lessw2020 commented 1 month ago

Hi @zyushun, I've updated with your changes to integrate fused_rmsnorm. I felt the number of checks was confusing, though, so I simplified it (it should also be a tad faster) by using a set:

```python
# Specific layers, including fused_rmsnorm, are incompatible and are updated with plain AdamW
self.exclude_layers = {
    "embed_tokens", "wte", "lm_head", "tok_embeddings", "output.weight", "norm"
}
```

and then:

```python
# Handle excluded layer types via AdamW updates
if any(layer in name for layer in self.exclude_layers):
    ...  # existing AdamW update for this parameter
```

That now allows it to run with fused_rmsnorm, which adds about +5% MFU and slightly reduces GPU memory further. Anyway, the 10K run is going now.

One item, if you have time: I went to test my fused_rmsnorm changes with debug_model.toml, which is our go-to for verifying changes before larger runs, but now I get a shape-incompatibility error. I verified this is not related to the fused_rmsnorm changes. Could you take a look by switching run_llama_train.sh to debug_model.toml and see if you can figure out what is going awry there?

zyushun commented 1 month ago

Hi @lessw2020 ! Thanks soooooo much for your great advice! We will modify them immediately.

I am also working on the debug_model.toml now.

Bad news: for fused_rmsnorm, it seems my naive solution does not work :(. When rmsnorm is put into "exclude_layers", the loss won't go down. I will try to work out an alternative solution.

zyushun commented 1 month ago

Hi @lessw2020 ! Regarding the debug_model.toml.

The incompatibility error occurs because this debug model uses different n_heads, dim, and so on, and Adam-mini does not know this. It can be fixed as follows.

  1. Add "global model_config" at line 232 of train.py:

```python
global model_config
model_config = models_config[model_name][job_config.model.flavor]
```

  2. Change the Adam-mini arguments as follows:

```python
optimizer = Adam_mini(
    model,
    lr=lr,
    betas=(0.9, 0.95),
    weight_decay=0.1,
    model_sharding=True,
    n_embd=model_config.dim,                 # read from the model config instead of hardcoding
    n_head=model_config.n_heads,
    n_query_groups=model_config.n_kv_heads,
)
```

Best, Yushun

zyushun commented 1 month ago

Hi @lessw2020 , regarding fused_rmsnorm:

A strange thing happened. I did not change anything but the original Adam-mini seems to work with fused_rmsnorm. See the curve below. Note that I did NOT add "norm" to self.excluded_layers.

[Figure: loss curve for Adam-mini with fused_rmsnorm]

The log says we are indeed using fused_rmsnorm.

[rank0]:2024-07-08 13:40:10,845 - root - INFO - Building llama3 8B with ModelArgs(dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_batch_size=32, max_seq_len=2048, depth_init=True, norm_type='fused_rmsnorm')

Could you try again on your side to see if it is also working now?

lessw2020 commented 1 month ago

Hi @zyushun, thanks for the updates re: fused_rmsnorm and for the change that takes the inputs directly from the model config. I was able to run a short experiment tonight and can confirm that fused_rmsnorm works nicely with mini with no changes. This is very helpful, as fused improves MFU and lowers total memory, so it's a good combo! I haven't had time for a 70B run yet but still plan to do that next. Best regards, Less

zyushun commented 1 month ago

Hi @lessw2020 !

It is great to hear that things work well on your side, too. Thanks again for all the great discussion and suggestions! If you encounter any further challenges, please do not hesitate to contact me.

Best, Yushun

zyushun commented 1 month ago

Hi @lessw2020 !

We just released a new version of Adam_mini.py. This version integrates your suggestions and some other minor refinements. Please remember to update Adam_mini.py in your local codebase. You can call it as follows:

```python
from Adam_mini import Adam_mini

optimizer = Adam_mini(
    model=model,
    lr=lr,
    betas=(beta1, beta2),
    eps=eps,
    weight_decay=weight_decay,
    model_sharding=True,
    n_feature=model_config.dim,         # hidden size (formerly n_embd)
    n_head=model_config.n_heads,
    n_kv_head=model_config.n_kv_heads,  # formerly n_query_groups
)
```

Thanks for all the great suggestions!

Best, Yushun

lessw2020 commented 1 month ago

Hi Yushun, Great, thanks for the updated version. Will give it a run tomorrow morning and update here. Hope you are doing well, Less