EleutherAI / gpt-neox

An implementation of model parallel autoregressive transformers on GPUs, based on the Megatron and DeepSpeed libraries
https://www.eleuther.ai/
Apache License 2.0

The plot from muP coord_check does not look horizontal, which may indicate a bug in the muP implementation #956

Open BaoYu0721 opened 1 year ago

BaoYu0721 commented 1 year ago

Bug Description & To Reproduce: The source code is from the current main branch. I followed the instructions in README-MUP.md up to this step: [screenshot]. I then encountered this error:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /mnt/cache/baoyu/gpt-neox/train.py:27 in <module>                            │
│                                                                              │
│   24 │   neox_args.configure_distributed_args()                              │
│   25 │   neox_args.build_tokenizer()  # tokenizer needs to be build in train │
│   26 │   neox_args.initialize_tensorboard_writer()  # is initialized if tens │
│ ❱ 27 │   pretrain(neox_args=neox_args)                                       │
│   28                                                                         │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/training.py:211 in pretrain               │
│                                                                              │
│   208 │   timers("train/valid/test data iterators").stop()                   │
│   209 │                                                                      │
│   210 │   if neox_args.use_mup and neox_args.coord_check:                    │
│ ❱ 211 │   │   mup_coord_check(neox_args, timers, lr_scheduler, train_data_it │
│   212 │                                                                      │
│   213 │   # Print setup timing.                                              │
│   214 │   print_rank_0("done with setups ...")                               │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/training.py:154 in mup_coord_check        │
│                                                                              │
│   151 │   │   models[hidden_size] = lazy_model(hidden_size)                  │
│   152 │                                                                      │
│   153 │   neox_args.use_mup = True                                           │
│ ❱ 154 │   df_up = get_coord_data(                                            │
│   155 │   │   neox_args, timers, lr_scheduler, models, train_data_iterator,  │
│   156 │   )                                                                  │
│   157 │   neox_args.use_mup = False                                          │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/mup_substitute.py:207 in get_coord_data   │
│                                                                              │
│   204 │   elif optimizer is None:                                            │
│   205 │   │   raise ValueError("optimizer should be sgd|adam|adamw or a cust │
│   206 │                                                                      │
│ ❱ 207 │   data = _get_coord_data(                                            │
│   208 │   │   neox_args, timers, lr_scheduler, models, dataloader, optcls, * │
│   209 │   )                                                                  │
│   210 │   data["optimizer"] = optimizer                                      │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/mup_substitute.py:69 in _get_coord_data   │
│                                                                              │
│    66 │   │   │   │   │   )                                                  │
│    67 │   │   │   │                                                          │
│    68 │   │   │   │   # train for a step                                     │
│ ❱  69 │   │   │   │   loss_dict, skipped_iter = train_step(                  │
│    70 │   │   │   │   │   neox_args=neox_args,                               │
│    71 │   │   │   │   │   timers=timers,                                     │
│    72 │   │   │   │   │   data_iterator=dataloader,                          │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/training.py:695 in train_step             │
│                                                                              │
│   692 │                                                                      │
│   693 │   # Pipeline parallelism schedules forward/backward/step             │
│   694 │   if neox_args.is_pipe_parallel:                                     │
│ ❱ 695 │   │   reduced_loss = train_step_pipe(                                │
│   696 │   │   │   neox_args=neox_args, timers=timers, model=model, data_iter │
│   697 │   │   )                                                              │
│   698 │   else:                                                              │
│                                                                              │
│ /mnt/cache/baoyu/gpt-neox/megatron/training.py:745 in train_step_pipe        │
│                                                                              │
│   742 │   """Single training step with DeepSpeed's pipeline parallel engine. │
│   743 │                                                                      │
│   744 │   assert neox_args.deepspeed                                         │
│ ❱ 745 │   loss = model.train_batch(data_iter=data_iterator)                  │
│   746 │   loss_dict = {"lm_loss": loss}                                      │
│   747 │   # Don't break Megatron's timers because we changed code paths.     │
│   748 │   for t in [                                                         │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/engine.py:336 in train_batch                                            │
│                                                                              │
│    333 │   │   sched = schedule.TrainSchedule(micro_batches=self.micro_batch │
│    334 │   │   │   │   │   │   │   │   │      stages=self.num_stages,        │
│    335 │   │   │   │   │   │   │   │   │      stage_id=self.stage_id)        │
│ ❱  336 │   │   self._exec_schedule(sched)                                    │
│    337 │   │   self.agg_train_loss = self._aggregate_total_loss()            │
│    338 │   │                                                                 │
│    339 │   │   self.timers('train_batch').stop()                             │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/engine.py:1307 in _exec_schedule                                        │
│                                                                              │
│   1304 │   │   │   │                                                         │
│   1305 │   │   │   │   # Equivalent to: self._exec_forward_pass(buffer_id=0) │
│   1306 │   │   │   │   self._exec_instr = MethodType(self._INSTRUCTION_MAP[t │
│ ❱ 1307 │   │   │   │   self._exec_instr(**cmd.kwargs)                        │
│   1308                                                                       │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/engine.py:627 in _exec_forward_pass                                     │
│                                                                              │
│    624 │   │   # tensor changes across batches                               │
│    625 │   │   self._zero_grads(inputs)                                      │
│    626 │   │                                                                 │
│ ❱  627 │   │   outputs = super().forward(inputs)                             │
│    628 │   │                                                                 │
│    629 │   │   # Reset activation checkpointing buffers.                     │
│    630 │   │   # Need to call this between evaluation iterations             │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/utils/nv │
│ tx.py:15 in wrapped_fn                                                       │
│                                                                              │
│   12 │                                                                       │
│   13 │   def wrapped_fn(*args, **kwargs):                                    │
│   14 │   │   get_accelerator().range_push(func.__qualname__)                 │
│ ❱ 15 │   │   ret_val = func(*args, **kwargs)                                 │
│   16 │   │   get_accelerator().range_pop()                                   │
│   17 │   │   return ret_val                                                  │
│   18                                                                         │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ engine.py:1731 in forward                                                    │
│                                                                              │
│   1728 │   │   if self.fp16_auto_cast():                                     │
│   1729 │   │   │   inputs = self._cast_inputs_half(inputs)                   │
│   1730 │   │                                                                 │
│ ❱ 1731 │   │   loss = self.module(*inputs, **kwargs)                         │
│   1732 │   │                                                                 │
│   1733 │   │   if self.zero_optimization_partition_weights():                │
│   1734 │   │   │   # Disable automated discovery of external parameters      │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/torch/nn/modules/m │
│ odule.py:1212 in _call_impl                                                  │
│                                                                              │
│   1209 │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks)   │
│   1210 │   │   │   input = bw_hook.setup_input_hook(input)                   │
│   1211 │   │                                                                 │
│ ❱ 1212 │   │   result = forward_call(*input, **kwargs)                       │
│   1213 │   │   if _global_forward_hooks or self._forward_hooks:              │
│   1214 │   │   │   for hook in (*_global_forward_hooks.values(), *self._forw │
│   1215 │   │   │   │   hook_result = hook(self, input, result)               │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/module.py:350 in forward                                                │
│                                                                              │
│   347 │   │   │   │   if self._is_checkpointable(funcs):                     │
│   348 │   │   │   │   │   x = self.activation_checkpoint_func(exec_range_fun │
│   349 │   │   │   │   else:                                                  │
│ ❱ 350 │   │   │   │   │   x = exec_range_func(start_idx, end_idx)(*x)        │
│   351 │   │   return x                                                       │
│   352 │                                                                      │
│   353 │   def _partition_layers(self, method='uniform'):                     │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/deepspeed/runtime/ │
│ pipe/module.py:327 in exec_func                                              │
│                                                                              │
│   324 │   │   │   │   │   │   else:                                          │
│   325 │   │   │   │   │   │   │   ds_utils.set_random_seed(new_seed)         │
│   326 │   │   │   │   │                                                      │
│ ❱ 327 │   │   │   │   │   inputs = layer(inputs)                             │
│   328 │   │   │   │   return inputs                                          │
│   329 │   │   │                                                              │
│   330 │   │   │   return exec_func                                           │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/torch/nn/modules/m │
│ odule.py:1215 in _call_impl                                                  │
│                                                                              │
│   1212 │   │   result = forward_call(*input, **kwargs)                       │
│   1213 │   │   if _global_forward_hooks or self._forward_hooks:              │
│   1214 │   │   │   for hook in (*_global_forward_hooks.values(), *self._forw │
│ ❱ 1215 │   │   │   │   hook_result = hook(self, input, result)               │
│   1216 │   │   │   │   if hook_result is not None:                           │
│   1217 │   │   │   │   │   result = hook_result                              │
│   1218                                                                       │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/mup/coord_check.py │
│ :161 in f                                                                    │
│                                                                              │
│   158 │   │   │   │   for i, out in enumerate(output):                       │
│   159 │   │   │   │   │   _ret = copy(ret)                                   │
│   160 │   │   │   │   │   _ret['module'] += f':out[{i}]'                     │
│ ❱ 161 │   │   │   │   │   get_stat(_ret, out, output_fdict)                  │
│   162 │   │   │   elif isinstance(output, dict):                             │
│   163 │   │   │   │   for name, out in output.items():                       │
│   164 │   │   │   │   │   _ret = copy(ret)                                   │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/mup/coord_check.py │
│ :145 in get_stat                                                             │
│                                                                              │
│   142 │   │   │   elif isinstance(x, torch.Tensor):                          │
│   143 │   │   │   │   _d = copy(d)                                           │
│   144 │   │   │   │   for fname, f in fdict.items():                         │
│ ❱ 145 │   │   │   │   │   _d[fname] = f(x).item()                            │
│   146 │   │   │   │   records.append(_d)                                     │
│   147 │   │   │   else:                                                      │
│   148 │   │   │   │   raise NotImplemented(f'Unexpected output type: {type(x │
│                                                                              │
│ /mnt/cache/baoyu/packages/usr/lib/python3.8/site-packages/mup/coord_check.py │
│ :44 in <lambda>                                                              │
│                                                                              │
│    41                                                                        │
│    42 #: dict of provided functions for use in coord check                   │
│    43 FDICT = {                                                              │
│ ❱  44 │   'l1': lambda x: torch.abs(x).mean(),                               │
│    45 │   'l2': lambda x: (x**2).mean()**0.5,                                │
│    46 │   'mean': lambda x: x.mean(),                                        │
│    47 │   'std': lambda x: x.std(),                                          │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a
floating point or complex dtype. Got: Bool

This is caused by passing a Bool tensor (maybe the attention mask) into mup's get_stat, which cannot handle it. In addition, passing None to get_stat causes another error.

As a temporary workaround, I modified coord_check.py in the mup package as follows: [screenshot]
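The patch itself was only posted as a screenshot. Below is a minimal sketch of the same idea. It assumes the workaround is simply to skip `None` outputs and non-floating-point tensors (such as a Bool attention mask) before computing statistics. `collect_stats` is a hypothetical helper, not mup's actual `get_stat`; the statistics mirror the `FDICT` entries visible in the traceback.

```python
import torch

# Statistics mirroring the FDICT entries visible in the traceback above.
FDICT = {
    "l1": lambda x: torch.abs(x).mean(),
    "l2": lambda x: (x ** 2).mean() ** 0.5,
    "mean": lambda x: x.mean(),
    "std": lambda x: x.std(),
}


def collect_stats(x, fdict=FDICT):
    """Return {stat_name: float} for a module output, or None to skip it.

    Hypothetical workaround: None outputs and tensors that are not floating
    point (Bool masks, integer ids and, for simplicity, complex tensors)
    are skipped instead of being passed to mean(), which raises on Bool.
    """
    if x is None:
        return None
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"Unexpected output type: {type(x)}")
    if not x.is_floating_point():
        return None
    return {name: f(x).item() for name, f in fdict.items()}


if __name__ == "__main__":
    mask = torch.ones(2, 3, dtype=torch.bool)
    print(collect_stats(mask))              # None: Bool tensors are skipped
    print(collect_stats(torch.ones(2, 3)))  # l1/l2/mean are 1.0, std is 0.0
```

In a real patch, the caller would also need to skip appending a record whenever `None` is returned.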

This time coord_check ran successfully. It outputs one JPG per GPU. The JPGs from different GPUs look very similar, so I show only one per parameterization.

Standard Parameterization (SP): [plot: coord_check_sp]

muP Parameterization: [plot: coord_check_up]

The result looks wrong: the SP curves are flatter than the muP curves, which is the opposite of what is expected.

Expected behavior: The plots should look like those at https://github.com/microsoft/mup#coord-check, where the muP curves are nearly horizontal while the SP curves blow up.
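One rough way to put a number on "horizontal" is to fit the slope of log(statistic) against log(width) for one module at one training step. This is a minimal sketch under that assumption. `loglog_slope` and `looks_horizontal` are hypothetical helpers, the statistic values must be positive (as l1 and l2 are), and the 0.1 tolerance is an arbitrary choice.

```python
import math


def loglog_slope(widths, values):
    """Least-squares slope of log(value) against log(width).

    For a coord-check curve (e.g. the l1 of one module's output at a fixed
    training step), a slope near 0 means the curve is flat in width.
    Values must be positive.
    """
    if len(widths) != len(values) or len(widths) < 2:
        raise ValueError("need at least two (width, value) pairs")
    xs = [math.log(w) for w in widths]
    ys = [math.log(v) for v in values]
    mx = sum(xs) / len(xs)
    my = sum(ys) / len(ys)
    num = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    den = sum((x - mx) ** 2 for x in xs)
    return num / den


def looks_horizontal(widths, values, tol=0.1):
    """Hypothetical pass/fail rule; the tolerance is an arbitrary choice."""
    return abs(loglog_slope(widths, values)) <= tol
```

Under muP the slope should stay near 0 across steps. An SP curve that blows up with width shows up as a clearly positive slope.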

Proposed solution: I have checked the code related to mup but don't have a proposal yet; I will keep looking. Maybe the contributors to the earlier issue (https://github.com/EleutherAI/gpt-neox/issues/679) can comment? @nsarka @Quentin-Anthony @StellaAthena Thanks a lot!


StellaAthena commented 1 year ago

Thanks for raising this issue. It looks like you’re correct and we broke the implementation at some point.

One thing we really need to start doing (but haven’t been able to do due to manpower limitations) is build out a robust testing suite that verifies new major changes don’t break old features :S

BaoYu0721 commented 1 year ago

Thanks for your reply! I checked out some other commits, such as the v2.0 release tag and the earlier commit where deepspeed_main was merged into main (2b84f9af10eebdb82dfb956adc2cb54ba2f62344), and found plots similar to those described above. Maybe the bug was introduced even earlier?

ofivite commented 1 year ago

I was looking into the muP implementation in gpt-neox to contrast it with the Megatron-LM setup and accidentally found this issue :)

I am wondering whether the LR schedule could be the cause of the problem. By design, it overrides the LR of each param group (and hence overwrites the muP adjustments), which is why muP scaling was introduced in AnnealingLR() as a rescaling by group["width_mult"] in step(). However, I couldn't find where this key is added, either in the mup optimisers or in the gpt-neox codebase, so I am not sure the width_mult rescaling is applied at all.
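Below is a minimal sketch of the interaction described above. It assumes a scheduler that writes one LR into every param group and then divides by an optional `width_mult` key. `annealing_lr` and `scheduler_step` are hypothetical and are not gpt-neox's `AnnealingLR`. The warmup-plus-cosine shape and the division by `width_mult` (the Adam-style rule for matrix-like params) are also assumptions. The point is that if nothing ever sets `width_mult`, the lookup falls back to 1.0 and the muP rescaling silently does nothing.

```python
import math


def annealing_lr(step, max_lr, warmup, total, min_lr=0.0):
    """Hypothetical linear-warmup + cosine-decay schedule."""
    if step < warmup:
        return max_lr * step / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


def scheduler_step(param_groups, step, **sched_kwargs):
    """Overwrite every group's lr, as a scheduler does, then reapply muP scaling.

    Groups without a "width_mult" key fall back to 1.0, i.e. no rescaling.
    """
    new_lr = annealing_lr(step, **sched_kwargs)
    for group in param_groups:
        group["lr"] = new_lr / group.get("width_mult", 1.0)
```

Any muP-specific LR set by the optimizer is lost on the first `scheduler_step` call unless it is re-derived from a per-group key like this.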

Also, rescaling by width_mult is only correct for Adam-like optimisers and matrix-like params. SGD uses different multipliers, which should also be taken into account.
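Below is a minimal sketch of the per-parameter LR factors. It is modeled on my reading of how the Microsoft mup package's `MuAdam` and `MuSGD` group parameters by their number of infinite (width-scaling) dimensions. `mup_lr_scale` is hypothetical, and the rules should be verified against the mup source before relying on them. For Adam-like optimizers, only matrix-like params (two infinite dims) are divided by `width_mult`. For SGD, vector-like params (one infinite dim, e.g. biases) are multiplied by `width_mult`, and matrix-like params are divided by the fan-in/fan-out multiplier ratio.

```python
def mup_lr_scale(optimizer, n_inf_dims, width_mult=1.0, fanin_fanout_ratio=1.0):
    """Return the factor applied to the base lr for one parameter (sketch).

    n_inf_dims == 0: fixed-size param, no scaling
    n_inf_dims == 1: "vector-like" (e.g. biases, embeddings)
    n_inf_dims == 2: "matrix-like" (e.g. hidden weights)
    """
    if n_inf_dims == 0:
        return 1.0
    if n_inf_dims > 2:
        raise NotImplementedError("more than 2 infinite dimensions")
    if optimizer in ("adam", "adamw"):
        # Adam-like: only matrix-like params are rescaled, by 1 / width_mult.
        return 1.0 / width_mult if n_inf_dims == 2 else 1.0
    if optimizer == "sgd":
        # SGD: vector-like lr grows with width; matrix-like lr is divided by
        # the fan_in / fan_out multiplier ratio (1.0 for square hidden weights).
        return width_mult if n_inf_dims == 1 else 1.0 / fanin_fanout_ratio
    raise ValueError(f"unknown optimizer: {optimizer}")
```

A single `width_mult` division in the LR scheduler therefore only matches the Adam case. An SGD run would need the per-group factors to differ by parameter type.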

However, I also couldn't confirm whether the AnnealingLR schedule is actually applied during training, so my comment may not be relevant to the observed behaviour.

marcobellagente93 commented 1 year ago

I think you are right @ofivite. I don't think the implementation was ever correct, since the learning rate wasn't set up correctly from the beginning. After long and thorough debugging, I managed to pass at least the two basic sanity checks for muP:

  1. at the same width, muP doesn't do anything (since all shapes are the same, it should coincide with SP)
  2. with everything else fixed, you only get better by going wider
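The two sanity checks above can be expressed as simple assertions on logged losses. Below is a minimal sketch with hypothetical helpers that take final losses (or loss curves) collected from your own runs. The tolerances are arbitrary choices. Check 1 also assumes both runs use the same seed and data order.

```python
def same_width_matches_sp(sp_losses, mup_losses, rtol=1e-3):
    """Check 1: at the base width, muP and SP loss curves should coincide."""
    if len(sp_losses) != len(mup_losses):
        return False
    return all(
        abs(a - b) <= rtol * max(abs(a), abs(b), 1e-12)
        for a, b in zip(sp_losses, mup_losses)
    )


def wider_is_better(final_loss_by_width, tol=0.0):
    """Check 2: with everything else fixed, loss should not rise with width."""
    widths = sorted(final_loss_by_width)
    losses = [final_loss_by_width[w] for w in widths]
    return all(nxt <= prev + tol for prev, nxt in zip(losses, losses[1:]))
```

A small positive `tol` in `wider_is_better` can absorb run-to-run noise at neighbouring widths.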

Issues I have found are:

marcobellagente93 commented 1 year ago

Found another bug: neox_args.use_mup is set to False before the models are initialized. This also sets their use_mup attribute to False, so the multipliers are always ignored.

marcobellagente93 commented 1 year ago

Finally, there seems to be a bug in the re-initialization of the output layer. After skipping it completely (it should follow Table 8 anyway), I get these nice, smooth horizontal lines:

[plot: coord_check_up]

ofivite commented 1 year ago

@marcobellagente93 Oh yes, now the curves are nicely flat, great! :)

marcobellagente93 commented 1 year ago

I'll make a PR as soon as I can

nsarka commented 1 year ago

Edit: email formatting did not work properly.

Hi, I’m interested to see the PR as well. When I originally made my PR, the curves looked as expected: flat for mup and blown up for SP.

"The use_mup parameter is set to false so the weights never get initialized"

Are you referring to it being false by default? When calculating the base shapes using mup, two models have to be instantiated, one with use_mup set to false and the other set to true. If I remember correctly, I set use_mup to false by default everywhere and enabled it only when mup was set to true in the config file. Then, when calculating base shapes, it’s forced to false.
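Calculating base shapes is what ultimately decides which parameters get rescaled and by how much. Below is a minimal sketch, assuming the convention described in the mup README for `set_base_shapes`. Under that convention, a dimension counts as infinite (scaling with width) if it differs between the base and delta models, and its multiplier is target size over base size. `infer_width_mults` and `n_inf_dims` are hypothetical helpers, and the shapes are illustrative.

```python
def infer_width_mults(base_shape, delta_shape, target_shape):
    """Return, per dimension, target/base if the dim is 'infinite', else None."""
    if not (len(base_shape) == len(delta_shape) == len(target_shape)):
        raise ValueError("shapes must have the same rank")
    mults = []
    for base, delta, target in zip(base_shape, delta_shape, target_shape):
        # A dim that differs between the base and delta models scales with width.
        mults.append(target / base if base != delta else None)
    return mults


def n_inf_dims(mults):
    """Number of width-scaling dimensions (1 = vector-like, 2 = matrix-like)."""
    return sum(m is not None for m in mults)


# Hidden weight: both dims scale with width.
print(infer_width_mults((256, 256), (512, 512), (1024, 1024)))       # [4.0, 4.0]
# Embedding (vocab x hidden): only the hidden dim scales.
print(infer_width_mults((50304, 256), (50304, 512), (50304, 1024)))  # [None, 4.0]
```

If the base and delta models are built with inconsistent settings, the wrong dimensions can end up marked as infinite. Every multiplier derived from them would then be wrong.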

"The learning rate is not scaled by width_mult"

I saw some code being linked to in this thread that’s from the original mup repo from Microsoft. I had to bring in and modify a lot of it. I believe width_mult for the learning rate may have been set there.

marcobellagente93 commented 1 year ago

What I mean is that the main training loop with mup enabled does the following:

  1. set neox_args.use_mup to false
  2. initialize model
  3. set neox_args.use_mup back to true

but because of step 1, all parameters initialized at step 2 get self.use_mup = neox_args.use_mup (which is False at that point), and this makes everything else wrong (multipliers not used, 1/d attention not used, ...)
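Below is a minimal sketch of why the order matters. It assumes, as described above, that each layer copies `neox_args.use_mup` into `self.use_mup` when it is constructed. `EagerLayer`, `LazyLayer`, and the `1 / width_mult` multiplier are hypothetical. The sketch shows only Python attribute semantics and says nothing about whether gpt-neox later reinitializes the weights.

```python
from types import SimpleNamespace


class EagerLayer:
    """Hypothetical layer that copies the flag once, in __init__."""

    def __init__(self, neox_args):
        self.use_mup = neox_args.use_mup

    def output_multiplier(self, width_mult):
        # Hypothetical muP-style multiplier; 1.0 means "behave like SP".
        return 1.0 / width_mult if self.use_mup else 1.0


class LazyLayer(EagerLayer):
    """Hypothetical variant that reads the flag at call time instead."""

    def __init__(self, neox_args):
        self.neox_args = neox_args  # keep a reference, not a copy

    @property
    def use_mup(self):
        return self.neox_args.use_mup


neox_args = SimpleNamespace(use_mup=True)
neox_args.use_mup = False       # step 1
eager = EagerLayer(neox_args)   # step 2
lazy = LazyLayer(neox_args)
neox_args.use_mup = True        # step 3

print(eager.output_multiplier(4.0))  # 1.0: the stale False was captured
print(lazy.output_multiplier(4.0))   # 0.25: sees the restored True
```

Re-syncing the attribute after step 3, or reading the flag lazily as above, would avoid the stale value. Whether that is needed depends on whether a later reinitialization already handles it.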

nsarka commented 1 year ago

This behavior is expected: the weights are reinitialized using

https://github.com/EleutherAI/gpt-neox/blob/43ea51c2f3aeef2fc642ba401ce08844eb5a0240/megatron/training.py#L446

Or do you mean this function does not get called?