zyushun / Adam-mini

Code for Adam-mini: Use Fewer Learning Rates To Gain More https://arxiv.org/abs/2406.16793

Adam-mini <> DTensor #14

Closed: awgu closed this issue 3 months ago

awgu commented 3 months ago

Hi @zyushun! Thanks for your awesome work. I wanted to take some time to understand what a new optimizer needs in order to work with DTensor, so I worked through your implementation.

My version of the implementation is below (pardon some of my stylistic changes):

Adam-mini implementation without explicit collectives

```python
import math
from typing import Iterable, Optional, Tuple, Union

import torch
import torch.nn as nn


class AdamWMini(torch.optim.Optimizer):
    def __init__(
        self,
        named_parameters: Iterable[Tuple[str, nn.Parameter]],
        lr: Union[float, torch.Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        *,
        dim: int = 2048,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
    ):
        self.dim = dim
        self.n_heads = n_heads
        if n_kv_heads is not None:
            assert n_heads % n_kv_heads == 0, f"{n_heads} {n_kv_heads}"
            self.n_kv_heads = n_kv_heads
        else:
            self.n_kv_heads = n_heads

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not self.dim == int(self.dim):
            raise ValueError("Invalid dim value: {}".format(self.dim))
        if not self.n_heads == int(self.n_heads):
            raise ValueError("Invalid n_heads value: {}".format(self.n_heads))
        if not self.n_kv_heads == int(self.n_kv_heads):
            raise ValueError("Invalid n_kv_heads value: {}".format(self.n_kv_heads))

        # One param group per parameter so that each group remembers its name.
        optim_groups = []
        count_embd = count_output = count_wq = count_wk = 0
        for param_name, param in named_parameters:
            if not param.requires_grad:
                continue
            state = {}
            state["name"] = param_name
            state["params"] = param
            if "norm" in param_name or "ln_f" in param_name:
                state["weight_decay"] = 0.0
            else:
                state["weight_decay"] = weight_decay
            if "embed" in param_name or "wte" in param_name or "embd" in param_name:
                count_embd += 1
            if "lm_head.weight" in param_name or "output.weight" in param_name:
                count_output += 1
            if "q_proj.weight" in param_name or "wq.weight" in param_name:
                count_wq += 1
                assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}"
                state["head_numel"] = self.dim * self.dim // self.n_heads
            if "k_proj.weight" in param_name or "wk.weight" in param_name:
                count_wk += 1
                assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}"
                state["head_numel"] = self.dim * self.dim // self.n_heads
            optim_groups.append(state)

        self.embd_names = {"embed", "embd", "wte", "lm_head.weight", "output.weight"}
        self.wqk_names = {"k_proj.weight", "q_proj.weight", "wq.weight", "wk.weight"}

        defaults = dict(lr=lr, beta1=betas[0], beta2=betas[1], eps=eps)
        super().__init__(optim_groups, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            beta1 = group["beta1"]
            beta2 = group["beta2"]
            lr = group["lr"]
            name = group["name"]
            eps = group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if any(embd_name in name for embd_name in self.embd_names):
                    # Embedding/output weights: plain AdamW with a per-element v.
                    if len(state) == 0:
                        state["m"] = torch.zeros_like(p, dtype=torch.float32)
                        state["step"] = 0
                        state["v"] = torch.zeros_like(p, dtype=torch.float32)
                    grad = p.grad.to(torch.float32)
                    state["v"].mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
                    state["step"] += 1
                    if group["weight_decay"] > 0.0:
                        p.mul_(1 - lr * group["weight_decay"])
                    state["m"].lerp_(grad, 1 - beta1)
                    bias_correction_1 = 1 - beta1 ** state["step"]
                    bias_correction_2 = 1 - beta2 ** state["step"]
                    bias_correction_2_sqrt = math.sqrt(bias_correction_2)
                    h = (state["v"].sqrt() / bias_correction_2_sqrt).add_(eps)
                    stepsize = lr / bias_correction_1
                    p.addcdiv_(state["m"], h, value=-stepsize)
                elif any(wqk_name in name for wqk_name in self.wqk_names):
                    # Query/key weights: one second-moment scalar per attention head.
                    dim = group["head_numel"]
                    if len(state) == 0:
                        m = torch.zeros_like(p, dtype=torch.float32)
                        state["m"] = m.view(-1, dim)
                        state["head"] = state["m"].size(0)
                        state["step"] = 0
                        # NOTE: We must use `zeros_like` for vmean to be a
                        # DTensor (not `torch.Tensor`) for DTensor parameters.
                        # state["vmean"] = torch.zeros(state["head"])
                        state["vmean"] = torch.zeros_like(state["m"][0:state["head"], 0:1])
                    grad = p.grad.to(torch.float32)
                    head = state["head"]
                    grad = grad.view(head, dim)
                    tmp_lr = torch.mean(grad * grad, dim=1, keepdim=True)
                    state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2)
                    state["step"] += 1
                    if group["weight_decay"] > 0.0:
                        p.mul_(1 - lr * group["weight_decay"])
                    state["m"].lerp_(grad, 1 - beta1)
                    bias_correction_1 = 1 - beta1 ** state["step"]
                    bias_correction_2 = 1 - beta2 ** state["step"]
                    bias_correction_2_sqrt = math.sqrt(bias_correction_2)
                    h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps)
                    stepsize = ((1 / bias_correction_1) / h).view(head, 1)
                    update = (state["m"] * stepsize).view(p.size())
                    update.mul_(lr)
                    p.add_(-update)
                else:
                    # All other weights: one second-moment scalar for the whole tensor.
                    if len(state) == 0:
                        dim = p.numel()
                        state["m"] = torch.zeros_like(p, dtype=torch.float32)
                        state["step"] = 0
                        # NOTE: We must use `new_zeros` for vmean to be a
                        # DTensor (not `torch.Tensor`) for DTensor parameters.
                        # Keep it in fp32 like `m`, even for low-precision params.
                        # state["vmean"] = torch.zeros(1, device=p.device)
                        state["vmean"] = p.new_zeros(1, dtype=torch.float32)
                        state["dim"] = dim
                    grad = p.grad.to(torch.float32)
                    tmp_lr = torch.sum(grad * grad)
                    tmp_lr = tmp_lr / state["dim"]
                    if group["weight_decay"] > 0.0:
                        p.mul_(1 - lr * group["weight_decay"])
                    state["step"] += 1
                    state["m"].lerp_(grad, 1 - beta1)
                    bias_correction_1 = 1 - beta1 ** state["step"]
                    bias_correction_2 = 1 - beta2 ** state["step"]
                    bias_correction_2_sqrt = math.sqrt(bias_correction_2)
                    state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2)
                    h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps)
                    stepsize = (1 / bias_correction_1) / h
                    update = state["m"] * (stepsize.to(state["m"].device))
                    update.mul_(lr)
                    p.add_(-update)
        return loss
```
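To isolate the core idea of the implementation above (a full per-element first moment `m`, but only one second-moment scalar `vmean` per block), here is a minimal single-block sketch. The function name `adam_mini_block_step` is hypothetical and the sketch is not the repo's code; it only checks that, for a block whose gradient entries all share the same magnitude, one step matches `torch.optim.Adam` with the same hyperparameters.

```python
import math

import torch


def adam_mini_block_step(p, grad, m, vmean, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: update one parameter block in place, keeping a full
    first moment `m` but a single second-moment scalar `vmean` for the block."""
    grad = grad.to(torch.float32)
    # One scalar: exponential moving average of mean(g^2) over the whole block.
    vmean.mul_(beta2).add_(grad.pow(2).mean(), alpha=1 - beta2)
    # Per-element first moment: m <- beta1 * m + (1 - beta1) * grad.
    m.lerp_(grad, 1 - beta1)
    bias_correction_1 = 1 - beta1 ** step
    bias_correction_2 = 1 - beta2 ** step
    h = (vmean.sqrt() / math.sqrt(bias_correction_2)).add_(eps)
    p.add_(m / h, alpha=-lr / bias_correction_1)


# Toy block whose gradient entries all have magnitude 0.5.
grad = 0.5 * torch.tensor([1.0, -1.0]).repeat(6).view(4, 3)

p_mini = torch.zeros(4, 3)
m = torch.zeros_like(p_mini)
vmean = torch.zeros(1)
adam_mini_block_step(p_mini, grad, m, vmean, step=1)

# Reference: one step of torch.optim.Adam with the same hyperparameters.
p_adam = torch.zeros(4, 3, requires_grad=True)
p_adam.grad = grad.clone()
torch.optim.Adam([p_adam], lr=1e-3, betas=(0.9, 0.999), eps=1e-8).step()

print(torch.allclose(p_mini, p_adam.detach()))  # True
```

With gradient entries of unequal magnitude the two updates differ: Adam normalizes each element by its own second moment, while this sketch scales the whole block by a single RMS value.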

The main requirement is the one you already identified: we must use `zeros_like` (or `new_zeros`) instead of `torch.zeros` to preserve DTensor-ness. Beyond that, if we use DTensor, we should not need any explicit communication in the optimizer code, such as: https://github.com/zyushun/Adam-mini/blob/bb02b74a4965f4e67661784e9ab3cc2a68c9eb52/Adam_mini.py#L230

https://github.com/zyushun/Adam-mini/blob/bb02b74a4965f4e67661784e9ab3cc2a68c9eb52/Adam_mini.py#L264-L266
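The reason `zeros_like`/`new_zeros` preserve tensor-subclass-ness while `torch.zeros` cannot can be illustrated without a distributed setup. The sketch below uses a trivial `torch.Tensor` subclass (`TaggedTensor`, a hypothetical stand-in) rather than DTensor itself; DTensor's actual mechanism differs (it is a `__torch_dispatch__` subclass), but the principle is the same: a factory call that receives the parameter can be intercepted by its subclass, whereas a bare factory call has no tensor argument to dispatch on.

```python
import torch


class TaggedTensor(torch.Tensor):
    """Trivial tensor subclass used here as a stand-in for DTensor."""


p = torch.ones(4, 2).as_subclass(TaggedTensor)

# These factories receive `p`, so the subclass gets a chance to handle the call.
a = torch.zeros_like(p, dtype=torch.float32)
b = p.new_zeros(1)
# A bare factory call has no tensor argument, so it returns a plain Tensor.
c = torch.zeros(1)

print(type(a).__name__, type(b).__name__, type(c).__name__)
# TaggedTensor TaggedTensor Tensor
```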

Some other stylistic changes:

Since this is a rewrite of your implementation, there might be bugs. I ran a sanity check on Llama3-8B on 8 GPUs with local batch size 1 and the default torchtitan hyperparameters (lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1):

[Screenshot, 2024-07-15: training curves for Adam-mini (purple) and AdamW (orange)]

Purple is Adam-mini, and orange is AdamW. Let me know if the curves were expected to match better than this even with default hyperparameters.

71.94 GiB peak reserved memory (AdamW) -> 65.75 GiB peak reserved memory (Adam-mini)
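As a rough cross-check, the sketch below computes the measured saving and a back-of-envelope estimate of how much fp32 second-moment state Adam-mini drops per GPU. It assumes Llama3-8B has about 8.03B parameters with vocabulary size 128256 and model dimension 4096, that the embedding and output weights keep a full per-element `v` (as the embedding branch of the implementation above does), and that all parameters are sharded evenly across the 8 GPUs. The estimate covers only `v`; the measured figure is peak reserved memory, which this estimate does not model, so the two are not expected to agree exactly.

```python
GiB = 2**30

# Measured peak reserved memory per GPU, from the run above.
adamw_peak_gib = 71.94
adam_mini_peak_gib = 65.75
saved_gib = adamw_peak_gib - adam_mini_peak_gib
print(f"measured: {saved_gib:.2f} GiB ({saved_gib / adamw_peak_gib:.1%})")  # 6.19 GiB (8.6%)

# Back-of-envelope for the fp32 second moment (assumptions listed above).
total_params = 8_030_261_248
full_v_params = 2 * 128_256 * 4096  # token embeddings + output projection keep a full v
blockwise_params = total_params - full_v_params  # v shrinks to one scalar per block
n_gpus = 8
bytes_per_fp32 = 4
est_saved_gib = blockwise_params * bytes_per_fp32 / n_gpus / GiB
print(f"estimated v savings: {est_saved_gib:.2f} GiB per GPU")  # 3.25 GiB per GPU
```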

I opened a PR in torchtitan to show the code changes: https://github.com/pytorch/torchtitan/pull/459. Ignore the PR underneath this one in the stack; it is related to some other changes I am testing.

zyushun commented 3 months ago

Hi @awgu ! Wow! Great thanks! We will definitely carefully read your implementation! 🚀🚀

Thanks so much for the refinement! It would help a lot!

We will update here soon.

zyushun commented 3 months ago

Hi @awgu !

We double-checked your code and it works well on my side. Thanks for all the professional changes! Now the code looks much better with your help.

Also, thanks for helping merge Adam-mini into torchtitan. It really means a lot to us! torchtitan is the simplest, fastest, and most readable pre-training codebase I've ever used. It is truly a great contribution to the community and will help many future AI researchers. Thanks for the great work. We feel truly honored that Adam-mini can be natively supported by such a great codebase.

Building on your revised code, we made some further changes to make it compatible with general (non-DTensor) Tensors. For instance, we put the `all_reduce` back, as it is needed for general Tensors. The revised version is now pip-installable with the code here. Please feel free to contact us if you have any further suggestions on style or anything else.
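Why the `all_reduce` is needed for plain Tensors can be seen with a single-process simulation. When a block's gradient is split across ranks as ordinary local tensors, each rank sees only its own shard, so the block-wise mean of g² requires summing the local sums of squares across ranks, which is what `torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.SUM)` provides. The sketch below substitutes a Python `sum` for that collective and assumes a simple row-wise split into 4 shards; it is an illustration, not a description of how the repo actually shards parameters.

```python
import torch

torch.manual_seed(0)
grad = torch.randn(8, 16)  # full gradient of one block (e.g. one attention head)

# Pretend the block is split row-wise across 4 ranks as plain local tensors.
shards = grad.chunk(4, dim=0)

# Each rank can only compute a local sum of squares over its own shard.
local_sq_sums = [s.pow(2).sum() for s in shards]

# all_reduce(SUM) would give every rank this total; a Python sum stands in for it.
global_sq_sum = sum(local_sq_sums)
block_mean_sq = global_sq_sum / grad.numel()

# Without the reduction, each rank would use a different, local-only mean.
local_means = [s.pow(2).mean() for s in shards]

print(torch.allclose(block_mean_sq, grad.pow(2).mean()))  # True
```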

Regarding the curves you reproduced: they look fine. If you train for more steps, like 5k or 10k, you will see Adam-mini gradually overlap with (or slightly outperform) AdamW.

Thanks again for all the support and suggestions!

zyushun commented 2 months ago

Dear Andrew,

Sorry for the delay! I missed your email during my busy schedule over the last two weeks (ICML conference + NeurIPS rebuttal). It is really good to see that Adam-mini can save more memory than Adam-8bit!

I noticed your pull request to integrate Adam-mini into torchtitan: https://github.com/pytorch/torchtitan/pull/459. It seems the pull request was not successful, so I am trying to help out here. Perhaps you can try our new pip-installable version (modified based on your stylistic changes) at https://github.com/zyushun/Adam-mini; it may help. Additionally, we have acknowledged you and Dr. Wright as important contributors to our repo (in the "Acknowledgement" section of the README).

Thanks for all the help. Your support means a lot to us. Please do not hesitate to contact me if you need any help.

Sincerely,
Yushun

On July 23, 2024, at 09:14, Andrew Gu <@.***> wrote:

I did a quick 1000-step run with Adam4bit on Llama3-8B on the c4 dataset on 8xH100s, and Adam4bit seems to be converging well. The memory savings are not that significant, though. Looking at the snapshots, the main issue is memory fragmentation.

Screenshots from 2024-07-22 (view on web):
https://github.com/user-attachments/assets/eca4b88e-eab7-4b81-8af2-10af4d4aea71
https://github.com/user-attachments/assets/2ccc5d02-4097-40d7-8ca0-278971700bb6
https://github.com/user-attachments/assets/114dd00f-26ee-4a85-aed6-a61cd5f7393f

torchtitan diff

```diff
diff --git a/train.py b/train.py
index b7eee30..3d5a2d0 100644
--- a/train.py
+++ b/train.py
@@ -116,6 +116,11 @@ def build_optimizers(model_parts, job_config: JobConfig):
             optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
         elif name == "AdamW":
             optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
```

Llama3-8B Config

```toml
# torchtitan Config.toml
# NOTE: this toml config is a preset for 64 A100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "compiled_rmsnorm"  # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW4bit"
lr = 3e-4

[training]
batch_size = 1
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
enable_float8_linear = false
compile = true
dataset = "c4"

[experimental]
pipeline_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'none'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
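For orientation, the per-step token count implied by this config can be worked out from the `[training]` values. This sketch assumes `batch_size` is the per-GPU (local) batch size, as in the earlier Llama3-8B run with local batch size 1, and that `data_parallel_degree = -1` resolves to all 8 GPUs of the 8xH100 run; both are assumptions about how torchtitan interprets these fields.

```python
# Values from the [training] section of the config above.
batch_size = 1  # assumed per-GPU (local) batch size
seq_len = 8192
steps = 1000
warmup_steps = 200

# Assumption: data_parallel_degree = -1 resolves to all 8 GPUs in this run.
dp_degree = 8

tokens_per_step = batch_size * seq_len * dp_degree
total_tokens = tokens_per_step * steps
print(tokens_per_step, total_tokens)  # 65536 65536000
print(f"warmup covers {warmup_steps / steps:.0%} of the run")  # warmup covers 20% of the run
```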


awgu commented 2 months ago

@zyushun Hey! I think it might be hard for Adam-mini to land in torchtitan given the philosophy of the repository. It is meant to be reference code, precisely so that it is easy to fork and try out new things like Adam-mini.

We can leave the PR open and possibly also include it as an example in the torchtitan README. We are doing this for some other changes to the code, like https://github.com/pytorch/torchtitan/pull/437.