microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] `zero_quantized_nontrainable_weights=True` when using PEFT+DeepSpeed with Mixed-Precision training using BF16 leads to `float != c10::BFloat16` error #4885

Closed: pacman100 closed this issue 4 months ago.

pacman100 commented 9 months ago

Describe the bug: Setting `zero_quantized_nontrainable_weights=True` when using PEFT + DeepSpeed with BF16 mixed-precision training leads to a `float != c10::BFloat16` error.

To Reproduce

Steps to reproduce the behavior:

  1. DeepSpeed Config: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/ds_config_z3_lora.json
  2. Accelerate Config: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/deepspeed_zeropp_lora_config.yaml
  3. Code: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/train.py
  4. Launch Command: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/run_peft_deepspeed_zeropp.sh
  5. Infrastructure: 8 GPUs with 80GB of memory each.
  6. Output logs with error:
      File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
        result = forward_call(*args, **kwargs)
      File "/fsx/sourab/transformers/src/transformers/models/mistral/modeling_mistral.py", line 356, in forward
        query_states = self.q_proj(hidden_states)
      File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
        result = forward_call(*args, **kwargs)
      File "/fsx/sourab/peft/src/peft/tuners/lora/layer.py", line 309, in forward
        result = self.base_layer(x, *args, **kwargs)
      File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
        result = forward_call(*args, **kwargs)
      File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
        return F.linear(input, self.weight, self.bias)
      File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 109, in zero3_linear_wrap
        return LinearFunctionForZeroStage3.apply(input, weight)
      File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
        return super().apply(*args, **kwargs)  # type: ignore[misc]
      File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
        return fwd(*args, **kwargs)
      File "/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 57, in forward
        output = input.matmul(weight.t())
    RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
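The failing call is the plain `input.matmul(weight.t())` inside DeepSpeed's `LinearFunctionForZeroStage3.forward`, where the activation (`mat1`) is fp32 and the frozen base-layer weight (`mat2`) is bf16. A minimal sketch, assuming only stock PyTorch on CPU (no DeepSpeed, PEFT, or GPUs involved), reproduces the same class of error; the tensor shapes are arbitrary placeholders and `reproduce_dtype_mismatch` is a hypothetical name used for illustration.

```python
import torch


def reproduce_dtype_mismatch():
    """Return the error message from an fp32 @ bf16 matmul, or None if it succeeds."""
    # fp32 activation: the "float" operand named in the error above.
    x = torch.randn(2, 4, dtype=torch.float32)
    # bf16 weight: the "c10::BFloat16" operand, as under BF16 mixed precision.
    weight = torch.randn(3, 4, dtype=torch.bfloat16)
    try:
        # Same expression as the failing line in deepspeed/runtime/zero/linear.py.
        x.matmul(weight.t())
    except RuntimeError as err:
        return str(err)
    return None


msg = reproduce_dtype_mismatch()
print(msg)  # exact wording differs between PyTorch versions
```

`matmul` does not promote mixed dtypes, so the mismatch surfaces as a `RuntimeError` rather than a silent upcast.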

Expected behavior: When using PEFT LoRA with DeepSpeed together with `zero_quantized_nontrainable_weights`, the non-trainable (frozen) weights should be quantized, yielding significant memory savings. This would enable fine-tuning even larger models or using larger batch sizes.
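For reference, a minimal sketch of the flag combination described above, written as a Python dict of the kind that can be passed to `deepspeed.initialize(config=...)`. It assumes `zero_quantized_nontrainable_weights` sits under `zero_optimization` next to the other ZeRO++ options. It is not a copy of the linked `ds_config_z3_lora.json`, and the micro-batch size is a placeholder.

```python
import json

# Hypothetical minimal config: values are placeholders, not the reporter's settings.
ds_config = {
    "bf16": {"enabled": True},  # BF16 mixed precision, as in the report
    "zero_optimization": {
        "stage": 3,  # ZeRO++ options build on ZeRO stage 3
        # Quantize frozen (non-LoRA) weights to save memory.
        "zero_quantized_nontrainable_weights": True,
    },
    "train_micro_batch_size_per_gpu": 1,  # placeholder value
}

print(json.dumps(ds_config, indent=2))
```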

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch']
torch version .................... 2.1.2+cu121
deepspeed install path ........... ['/fsx/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.12.5, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.1
deepspeed wheel compiled w. ...... torch 2.1, cuda 12.1
shared memory (/dev/shm) size .... 999.99 GB

System info (please complete the following information):

Launcher context: Accelerate launcher, which internally uses the DeepSpeed launcher.

sam-h-bean commented 5 months ago

Would love to see this fixed for training MoEs on DeepSpeed with quantization + BF16.

StefanHeng commented 4 months ago

Same issue here, training with BF16 + PEFT and ZeRO-3 with ZeRO++:

Stack trace:

Traceback (most recent call last):
  File "/nethome/yheng6/Co-Adaptation/src/co_adapt/reward_ensemble/stablelm_zephyr_3b/train.py", line 455, in <module>
    main()
  File "/nethome/yheng6/Co-Adaptation/src/co_adapt/reward_ensemble/stablelm_zephyr_3b/train.py", line 400, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 3138, in training_step
    loss = self.compute_loss(model, inputs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1081, in compute_loss
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1022, in get_batch_loss_metrics
    ) = self.concatenated_forward(model, batch)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 985, in concatenated_forward
    all_logits = model(
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1814, in forward
    loss = self.module(*inputs, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/peft_model.py", line 1129, in forward
    return self.base_model(
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
    return self.model.forward(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 1158, in forward
    outputs = self.model(
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 1026, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 458, in checkpoint
    ret = function(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 759, in forward
    self_attn_output, self_attn_weights, present_key_value = self.self_attn(
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 535, in forward
    query_states = self.q_proj(hidden_states)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 509, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 109, in zero3_linear_wrap
    return LinearFunctionForZeroStage3.apply(input, weight)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 57, in forward
    output = input.matmul(weight.t())
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != c10::Half
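In this second trace the operands disagree differently (bf16 activation vs. fp16 weight), but the failing line is the same `input.matmul(weight.t())`. As an illustration only, and not DeepSpeed's actual fix, a hypothetical helper that casts the activation to the weight's dtype before the product avoids the mismatch. Whether downcasting activations is numerically acceptable for a given run is an assumption to verify; upcasting the weight instead is the other obvious choice, at a memory cost. The demo uses the fp32/bf16 pairing from the original report, because fp16 matmul is not implemented on CPU in every PyTorch version.

```python
import torch


def dtype_aligned_linear(x, weight, bias=None):
    """Hypothetical helper: y = x @ weight.T (+ bias), computed in weight.dtype."""
    # Cast the activation to the weight's dtype so matmul sees matching operands.
    if x.dtype != weight.dtype:
        x = x.to(weight.dtype)
    out = x.matmul(weight.t())
    if bias is not None:
        # Keep the bias in the same dtype as the product.
        out = out + bias.to(out.dtype)
    return out


# fp32 activation with a bf16 weight (the pairing from the first report).
x = torch.randn(2, 4, dtype=torch.float32)
w = torch.randn(3, 4, dtype=torch.bfloat16)
y = dtype_aligned_linear(x, w)
print(y.dtype, tuple(y.shape))
```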
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /nethome/yheng6/Co-Adaptation/src/co_adapt/reward_ensemble/stablelm_zephyr_3b/train.py:455 in    │
│ <module>                                                                                         │
│                                                                                                  │
│   452                                                                                            │
│   453                                                                                            │
│   454 if __name__ == "__main__":                                                                 │
│ ❱ 455 │   main()                                                                                 │
│   456                                                                                            │
│                                                                                                  │
│ /nethome/yheng6/Co-Adaptation/src/co_adapt/reward_ensemble/stablelm_zephyr_3b/train.py:400 in    │
│ main                                                                                             │
│                                                                                                  │
│   397 │   │   checkpoint = training_args.resume_from_checkpoint                                  │
│   398 │   elif last_checkpoint is not None:                                                      │
│   399 │   │   checkpoint = last_checkpoint                                                       │
│ ❱ 400 │   train_result = trainer.train(resume_from_checkpoint=checkpoint)                        │
│   401 │   metrics = train_result.metrics                                                         │
│   402 │   metrics["train_samples"] = len(raw_datasets["train"])                                  │
│   403 │   trainer.log_metrics("train", metrics)                                                  │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py:18 │
│ 59 in train                                                                                      │
│                                                                                                  │
│   1856 │   │   │   finally:                                                                      │
│   1857 │   │   │   │   hf_hub_utils.enable_progress_bars()                                       │
│   1858 │   │   else:                                                                             │
│ ❱ 1859 │   │   │   return inner_training_loop(                                                   │
│   1860 │   │   │   │   args=args,                                                                │
│   1861 │   │   │   │   resume_from_checkpoint=resume_from_checkpoint,                            │
│   1862 │   │   │   │   trial=trial,                                                              │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py:22 │
│ 03 in _inner_training_loop                                                                       │
│                                                                                                  │
│   2200 │   │   │   │   │   self.control = self.callback_handler.on_step_begin(args, self.state,  │
│   2201 │   │   │   │                                                                             │
│   2202 │   │   │   │   with self.accelerator.accumulate(model):                                  │
│ ❱ 2203 │   │   │   │   │   tr_loss_step = self.training_step(model, inputs)                      │
│   2204 │   │   │   │                                                                             │
│   2205 │   │   │   │   if (                                                                      │
│   2206 │   │   │   │   │   args.logging_nan_inf_filter                                           │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py:31 │
│ 38 in training_step                                                                              │
│                                                                                                  │
│   3135 │   │   │   return loss_mb.reduce_mean().detach().to(self.args.device)                    │
│   3136 │   │                                                                                     │
│   3137 │   │   with self.compute_loss_context_manager():                                         │
│ ❱ 3138 │   │   │   loss = self.compute_loss(model, inputs)                                       │
│   3139 │   │                                                                                     │
│   3140 │   │   if self.args.n_gpu > 1:                                                           │
│   3141 │   │   │   loss = loss.mean()  # mean() to average on multi-gpu parallel training        │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py │
│ :1081 in compute_loss                                                                            │
│                                                                                                  │
│   1078 │   │   compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_ca  │
│   1079 │   │                                                                                     │
│   1080 │   │   with compute_loss_context_manager():                                              │
│ ❱ 1081 │   │   │   loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train  │
│   1082 │   │                                                                                     │
│   1083 │   │   # Make sure to move the loss to the device the original accumulating loss is at   │
│   1084 │   │   loss = loss.to(self.args.device)                                                  │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py │
│ :1022 in get_batch_loss_metrics                                                                  │
│                                                                                                  │
│   1019 │   │   │   policy_rejected_logps,                                                        │
│   1020 │   │   │   policy_chosen_logits,                                                         │
│   1021 │   │   │   policy_rejected_logits,                                                       │
│ ❱ 1022 │   │   ) = self.concatenated_forward(model, batch)                                       │
│   1023 │   │                                                                                     │
│   1024 │   │   # if reference_chosen_logps and reference_rejected_logps in batch use them, othe  │
│   1025 │   │   if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:     │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py │
│ :985 in concatenated_forward                                                                     │
│                                                                                                  │
│    982 │   │   │   if self.is_encoder_decoder                                                    │
│    983 │   │   │   else {}                                                                       │
│    984 │   │   )                                                                                 │
│ ❱  985 │   │   all_logits = model(                                                               │
│    986 │   │   │   concatenated_batch["concatenated_input_ids"],                                 │
│    987 │   │   │   attention_mask=concatenated_batch["concatenated_attention_mask"],             │
│    988 │   │   │   use_cache=False,                                                              │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1518 in _wrapped_call_impl                                                                      │
│                                                                                                  │
│   1515 │   │   if self._compiled_call_impl is not None:                                          │
│   1516 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1517 │   │   else:                                                                             │
│ ❱ 1518 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1519 │                                                                                         │
│   1520 │   def _call_impl(self, *args, **kwargs):                                                │
│   1521 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1527 in _call_impl                                                                              │
│                                                                                                  │
│   1524 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1525 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1526 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1527 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1528 │   │                                                                                     │
│   1529 │   │   try:                                                                              │
│   1530 │   │   │   result = None                                                                 │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/utils/nvtx.py:15 │
│ in wrapped_fn                                                                                    │
│                                                                                                  │
│   12 │                                                                                           │
│   13 │   def wrapped_fn(*args, **kwargs):                                                        │
│   14 │   │   get_accelerator().range_push(func.__qualname__)                                     │
│ ❱ 15 │   │   ret_val = func(*args, **kwargs)                                                     │
│   16 │   │   get_accelerator().range_pop()                                                       │
│   17 │   │   return ret_val                                                                      │
│   18                                                                                             │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/engine.p │
│ y:1814 in forward                                                                                │
│                                                                                                  │
│   1811 │   │   if self.fp16_auto_cast():                                                         │
│   1812 │   │   │   inputs = self._cast_inputs_half(inputs)                                       │
│   1813 │   │                                                                                     │
│ ❱ 1814 │   │   loss = self.module(*inputs, **kwargs)                                             │
│   1815 │   │                                                                                     │
│   1816 │   │   if self.zero_optimization_partition_weights():                                    │
│   1817 │   │   │   # Disable automated discovery of external parameters                          │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1518 in _wrapped_call_impl                                                                      │
│                                                                                                  │
│   1515 │   │   if self._compiled_call_impl is not None:                                          │
│   1516 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1517 │   │   else:                                                                             │
│ ❱ 1518 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1519 │                                                                                         │
│   1520 │   def _call_impl(self, *args, **kwargs):                                                │
│   1521 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1568 in _call_impl                                                                              │
│                                                                                                  │
│   1565 │   │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hoo  │
│   1566 │   │   │   │   args = bw_hook.setup_input_hook(args)                                     │
│   1567 │   │   │                                                                                 │
│ ❱ 1568 │   │   │   result = forward_call(*args, **kwargs)                                        │
│   1569 │   │   │   if _global_forward_hooks or self._forward_hooks:                              │
│   1570 │   │   │   │   for hook_id, hook in (                                                    │
│   1571 │   │   │   │   │   *_global_forward_hooks.items(),                                       │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/peft_model.py:1129 in │
│ forward                                                                                          │
│                                                                                                  │
│   1126 │   │   │                                                                                 │
│   1127 │   │   │   with self._enable_peft_forward_hooks(**kwargs):                               │
│   1128 │   │   │   │   kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_  │
│ ❱ 1129 │   │   │   │   return self.base_model(                                                   │
│   1130 │   │   │   │   │   input_ids=input_ids,                                                  │
│   1131 │   │   │   │   │   attention_mask=attention_mask,                                        │
│   1132 │   │   │   │   │   inputs_embeds=inputs_embeds,                                          │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1518 in _wrapped_call_impl                                                                      │
│                                                                                                  │
│   1515 │   │   if self._compiled_call_impl is not None:                                          │
│   1516 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1517 │   │   else:                                                                             │
│ ❱ 1518 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1519 │                                                                                         │
│   1520 │   def _call_impl(self, *args, **kwargs):                                                │
│   1521 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1568 in _call_impl                                                                              │
│                                                                                                  │
│   1565 │   │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hoo  │
│   1566 │   │   │   │   args = bw_hook.setup_input_hook(args)                                     │
│   1567 │   │   │                                                                                 │
│ ❱ 1568 │   │   │   result = forward_call(*args, **kwargs)                                        │
│   1569 │   │   │   if _global_forward_hooks or self._forward_hooks:                              │
│   1570 │   │   │   │   for hook_id, hook in (                                                    │
│   1571 │   │   │   │   │   *_global_forward_hooks.items(),                                       │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/tuners/tuners_utils.p │
│ y:161 in forward                                                                                 │
│                                                                                                  │
│   158 │   │   return self.active_adapter                                                         │
│   159 │                                                                                          │
│   160 │   def forward(self, *args: Any, **kwargs: Any):                                          │
│ ❱ 161 │   │   return self.model.forward(*args, **kwargs)                                         │
│   162 │                                                                                          │
│   163 │   @abstractmethod                                                                        │
│   164 │   def _prepare_adapter_config(self, peft_config: PeftConfig, model_config: dict) -> Pe   │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stable │
│ lm/modeling_stablelm.py:1158 in forward                                                          │
│                                                                                                  │
│   1155 │   │   )                                                                                 │
│   1156 │   │   return_dict = return_dict if return_dict is not None else self.config.use_return  │
│   1157 │   │                                                                                     │
│ ❱ 1158 │   │   outputs = self.model(                                                             │
│   1159 │   │   │   input_ids=input_ids,                                                          │
│   1160 │   │   │   attention_mask=attention_mask,                                                │
│   1161 │   │   │   position_ids=position_ids,                                                    │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1518 in _wrapped_call_impl                                                                      │
│                                                                                                  │
│   1515 │   │   if self._compiled_call_impl is not None:                                          │
│   1516 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1517 │   │   else:                                                                             │
│ ❱ 1518 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1519 │                                                                                         │
│   1520 │   def _call_impl(self, *args, **kwargs):                                                │
│   1521 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1568 in _call_impl                                                                              │
│                                                                                                  │
│   1565 │   │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hoo  │
│   1566 │   │   │   │   args = bw_hook.setup_input_hook(args)                                     │
│   1567 │   │   │                                                                                 │
│ ❱ 1568 │   │   │   result = forward_call(*args, **kwargs)                                        │
│   1569 │   │   │   if _global_forward_hooks or self._forward_hooks:                              │
│   1570 │   │   │   │   for hook_id, hook in (                                                    │
│   1571 │   │   │   │   │   *_global_forward_hooks.items(),                                       │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stable │
│ lm/modeling_stablelm.py:1026 in forward                                                          │
│                                                                                                  │
│   1023 │   │   │   │   all_hidden_states += (hidden_states,)                                     │
│   1024 │   │   │                                                                                 │
│   1025 │   │   │   if self.gradient_checkpointing and self.training:                             │
│ ❱ 1026 │   │   │   │   layer_outputs = self._gradient_checkpointing_func(                        │
│   1027 │   │   │   │   │   decoder_layer.__call__,                                               │
│   1028 │   │   │   │   │   hidden_states,                                                        │
│   1029 │   │   │   │   │   attention_mask,                                                       │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_compile.py:24 in    │
│ inner                                                                                            │
│                                                                                                  │
│   21 │   │   def inner(*args, **kwargs):                                                         │
│   22 │   │   │   import torch._dynamo                                                            │
│   23 │   │   │                                                                                   │
│ ❱ 24 │   │   │   return torch._dynamo.disable(fn, recursive)(*args, **kwargs)                    │
│   25 │   │                                                                                       │
│   26 │   │   return inner                                                                        │
│   27 │   else:                                                                                   │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_dynamo/eval_frame.p │
│ y:328 in _fn                                                                                     │
│                                                                                                  │
│    325 │   │   │   dynamic_ctx = enable_dynamic(self.dynamic, self.export)                       │
│    326 │   │   │   dynamic_ctx.__enter__()                                                       │
│    327 │   │   │   try:                                                                          │
│ ❱  328 │   │   │   │   return fn(*args, **kwargs)                                                │
│    329 │   │   │   finally:                                                                      │
│    330 │   │   │   │   set_eval_frame(prior)                                                     │
│    331 │   │   │   │   dynamic_ctx.__exit__(None, None, None)                                    │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_dynamo/external_uti │
│ ls.py:17 in inner                                                                                │
│                                                                                                  │
│   14 │                                                                                           │
│   15 │   @functools.wraps(fn)                                                                    │
│   16 │   def inner(*args, **kwargs):                                                             │
│ ❱ 17 │   │   return fn(*args, **kwargs)                                                          │
│   18 │                                                                                           │
│   19 │   return inner                                                                            │
│   20                                                                                             │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/utils/checkpoint.py: │
│ 458 in checkpoint                                                                                │
│                                                                                                  │
│    455 │   │   )                                                                                 │
│    456 │   │   # Runs pre-forward logic                                                          │
│    457 │   │   next(gen)                                                                         │
│ ❱  458 │   │   ret = function(*args, **kwargs)                                                   │
│    459 │   │   # Runs post-forward logic                                                         │
│    460 │   │   try:                                                                              │
│    461 │   │   │   next(gen)                                                                     │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1518 in _wrapped_call_impl                                                                      │
│                                                                                                  │
│   1515 │   │   if self._compiled_call_impl is not None:                                          │
│   1516 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1517 │   │   else:                                                                             │
│ ❱ 1518 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1519 │                                                                                         │
│   1520 │   def _call_impl(self, *args, **kwargs):                                                │
│   1521 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1568 in _call_impl                                                                              │
│                                                                                                  │
│   1565 │   │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hoo  │
│   1566 │   │   │   │   args = bw_hook.setup_input_hook(args)                                     │
│   1567 │   │   │                                                                                 │
│ ❱ 1568 │   │   │   result = forward_call(*args, **kwargs)                                        │
│   1569 │   │   │   if _global_forward_hooks or self._forward_hooks:                              │
│   1570 │   │   │   │   for hook_id, hook in (                                                    │
│   1571 │   │   │   │   │   *_global_forward_hooks.items(),                                       │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stable │
│ lm/modeling_stablelm.py:759 in forward                                                           │
│                                                                                                  │
│    756 │   │   hidden_states = self.input_layernorm(hidden_states)                               │
│    757 │   │                                                                                     │
│    758 │   │   # Self Attention                                                                  │
│ ❱  759 │   │   self_attn_output, self_attn_weights, present_key_value = self.self_attn(          │
│    760 │   │   │   hidden_states=hidden_states,                                                  │
│    761 │   │   │   attention_mask=attention_mask,                                                │
│    762 │   │   │   position_ids=position_ids,                                                    │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1518 in _wrapped_call_impl                                                                      │
│                                                                                                  │
│   1515 │   │   if self._compiled_call_impl is not None:                                          │
│   1516 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1517 │   │   else:                                                                             │
│ ❱ 1518 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1519 │                                                                                         │
│   1520 │   def _call_impl(self, *args, **kwargs):                                                │
│   1521 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1568 in _call_impl                                                                              │
│                                                                                                  │
│   1565 │   │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hoo  │
│   1566 │   │   │   │   args = bw_hook.setup_input_hook(args)                                     │
│   1567 │   │   │                                                                                 │
│ ❱ 1568 │   │   │   result = forward_call(*args, **kwargs)                                        │
│   1569 │   │   │   if _global_forward_hooks or self._forward_hooks:                              │
│   1570 │   │   │   │   for hook_id, hook in (                                                    │
│   1571 │   │   │   │   │   *_global_forward_hooks.items(),                                       │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stable │
│ lm/modeling_stablelm.py:535 in forward                                                           │
│                                                                                                  │
│    532 │   │                                                                                     │
│    533 │   │   bsz, q_len, _ = hidden_states.size()                                              │
│    534 │   │                                                                                     │
│ ❱  535 │   │   query_states = self.q_proj(hidden_states)                                         │
│    536 │   │   key_states = self.k_proj(hidden_states)                                           │
│    537 │   │   value_states = self.v_proj(hidden_states)                                         │
│    538                                                                                           │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1518 in _wrapped_call_impl                                                                      │
│                                                                                                  │
│   1515 │   │   if self._compiled_call_impl is not None:                                          │
│   1516 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1517 │   │   else:                                                                             │
│ ❱ 1518 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1519 │                                                                                         │
│   1520 │   def _call_impl(self, *args, **kwargs):                                                │
│   1521 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1568 in _call_impl                                                                              │
│                                                                                                  │
│   1565 │   │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hoo  │
│   1566 │   │   │   │   args = bw_hook.setup_input_hook(args)                                     │
│   1567 │   │   │                                                                                 │
│ ❱ 1568 │   │   │   result = forward_call(*args, **kwargs)                                        │
│   1569 │   │   │   if _global_forward_hooks or self._forward_hooks:                              │
│   1570 │   │   │   │   for hook_id, hook in (                                                    │
│   1571 │   │   │   │   │   *_global_forward_hooks.items(),                                       │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/tuners/lora/layer.py: │
│ 509 in forward                                                                                   │
│                                                                                                  │
│    506 │   │   │   │   x = x.to(lora_A.weight.dtype)                                             │
│    507 │   │   │   │                                                                             │
│    508 │   │   │   │   if not self.use_dora[active_adapter]:                                     │
│ ❱  509 │   │   │   │   │   result = result + lora_B(lora_A(dropout(x))) * scaling                │
│    510 │   │   │   │   else:                                                                     │
│    511 │   │   │   │   │   x = dropout(x)                                                        │
│    512 │   │   │   │   │   result = result + self._apply_dora(x, lora_A, lora_B, scaling, activ  │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1518 in _wrapped_call_impl                                                                      │
│                                                                                                  │
│   1515 │   │   if self._compiled_call_impl is not None:                                          │
│   1516 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1517 │   │   else:                                                                             │
│ ❱ 1518 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1519 │                                                                                         │
│   1520 │   def _call_impl(self, *args, **kwargs):                                                │
│   1521 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py │
│ :1568 in _call_impl                                                                              │
│                                                                                                  │
│   1565 │   │   │   │   bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hoo  │
│   1566 │   │   │   │   args = bw_hook.setup_input_hook(args)                                     │
│   1567 │   │   │                                                                                 │
│ ❱ 1568 │   │   │   result = forward_call(*args, **kwargs)                                        │
│   1569 │   │   │   if _global_forward_hooks or self._forward_hooks:                              │
│   1570 │   │   │   │   for hook_id, hook in (                                                    │
│   1571 │   │   │   │   │   *_global_forward_hooks.items(),                                       │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/linear.py │
│ :114 in forward                                                                                  │
│                                                                                                  │
│   111 │   │   │   init.uniform_(self.bias, -bound, bound)                                        │
│   112 │                                                                                          │
│   113 │   def forward(self, input: Tensor) -> Tensor:                                            │
│ ❱ 114 │   │   return F.linear(input, self.weight, self.bias)                                     │
│   115 │                                                                                          │
│   116 │   def extra_repr(self) -> str:                                                           │
│   117 │   │   return f'in_features={self.in_features}, out_features={self.out_features}, bias=   │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/zero/lin │
│ ear.py:109 in zero3_linear_wrap                                                                  │
│                                                                                                  │
│   106                                                                                            │
│   107 def zero3_linear_wrap(input, weight, bias=None):                                           │
│   108 │   if bias is None:                                                                       │
│ ❱ 109 │   │   return LinearFunctionForZeroStage3.apply(input, weight)                            │
│   110 │   else:                                                                                  │
│   111 │   │   return LinearFunctionForZeroStage3.apply(input, weight, bias)                      │
│   112                                                                                            │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/autograd/function.py │
│ :539 in apply                                                                                    │
│                                                                                                  │
│   536 │   │   if not torch._C._are_functorch_transforms_active():                                │
│   537 │   │   │   # See NOTE: [functorch vjp and autograd interaction]                           │
│   538 │   │   │   args = _functorch.utils.unwrap_dead_wrappers(args)                             │
│ ❱ 539 │   │   │   return super().apply(*args, **kwargs)  # type: ignore[misc]                    │
│   540 │   │                                                                                      │
│   541 │   │   if cls.setup_context == _SingleLevelFunction.setup_context:                        │
│   542 │   │   │   raise RuntimeError(                                                            │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/cuda/amp/autocast_mo │
│ de.py:113 in decorate_fwd                                                                        │
│                                                                                                  │
│   110 │   │   args[0]._dtype = torch.get_autocast_gpu_dtype()                                    │
│   111 │   │   if cast_inputs is None:                                                            │
│   112 │   │   │   args[0]._fwd_used_autocast = torch.is_autocast_enabled()                       │
│ ❱ 113 │   │   │   return fwd(*args, **kwargs)                                                    │
│   114 │   │   else:                                                                              │
│   115 │   │   │   autocast_context = torch.is_autocast_enabled()                                 │
│   116 │   │   │   args[0]._fwd_used_autocast = False                                             │
│                                                                                                  │
│ /nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/zero/lin │
│ ear.py:57 in forward                                                                             │
│                                                                                                  │
│    54 │   │   │   # fused op is marginally faster                                                │
│    55 │   │   │   ret = torch.addmm(bias, input, weight.t())                                     │
│    56 │   │   else:                                                                              │
│ ❱  57 │   │   │   output = input.matmul(weight.t())                                              │
│    58 │   │   │   if bias is not None:                                                           │
│    59 │   │   │   │   output += bias                                                             │
│    60 │   │   │   ret = output                                                                   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != c10::Half

Zero config:

{
    "fp16": {
        "enabled": false
    },
    "bf16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "none",
            "nvme_path": "None"
        },
        "offload_param": {
            "device": "none",
            "nvme_path": "None"
        },
        "stage3_gather_16bit_weights_on_model_save": true,

        "reduce_bucket_size": "auto",

        "zero_quantized_weights": true,
        "zero_hpz_partition_size": 2,
        "zero_quantized_gradients": true,

        "contiguous_gradients": true,
        "overlap_comm": true
    },
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": "inf"
}
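A quick way to spot the problematic combination in a config like the one above: the sketch below is a hypothetical helper (`zeropp_bf16_conflicts` is not part of DeepSpeed) that lists which ZeRO++ quantization switches are enabled together with `bf16`. It assumes, following the maintainer's reply in this thread, that ZeRO++ quantization does not support bf16. The config is a trimmed copy of the one in this report.

```python
import json

# Trimmed copy of the reporter's DeepSpeed config (only the keys the check reads).
DS_CONFIG = json.loads("""
{
    "fp16": {"enabled": false},
    "bf16": {"enabled": true},
    "zero_optimization": {
        "stage": 3,
        "zero_quantized_weights": true,
        "zero_hpz_partition_size": 2,
        "zero_quantized_gradients": true
    }
}
""")

# ZeRO++ quantization switches; assumed (per this thread) not to support bf16.
QUANT_KEYS = (
    "zero_quantized_weights",
    "zero_quantized_gradients",
    "zero_quantized_nontrainable_weights",
)


def zeropp_bf16_conflicts(config: dict) -> list:
    """Hypothetical helper: return the ZeRO++ quantization keys enabled alongside bf16."""
    if not config.get("bf16", {}).get("enabled", False):
        return []  # no bf16, so no conflict under this assumption
    zero = config.get("zero_optimization", {})
    return [key for key in QUANT_KEYS if zero.get(key, False)]


print(zeropp_bf16_conflicts(DS_CONFIG))
```

For this report's config the helper flags `zero_quantized_weights` and `zero_quantized_gradients`.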

Accelerate config:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
#  deepspeed_multinode_launcher: standard
#  offload_optimizer_device: none
#  offload_param_device: none
  zero3_init_flag: true
#  zero3_save_16bit_model: true
#  zero_stage: 3
  deepspeed_config_file: ./zero_configs/zero3++.json
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
#mixed_precision: bf16
num_machines: 1
#num_processes: 8
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Trained with the Trainer on 2 A5000 GPUs.

GuanhuaWang commented 4 months ago

Hi, currently the ZeRO++ feature does not support bf16 quantization; I suppose that is the root cause of this issue.

To fix it, you can either use fp16 as the dtype, or set `"zero_quantized_weights": false` and `"zero_quantized_gradients": false`.
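The two workarounds above amount to small edits of the DeepSpeed config. The sketch below shows them as a hypothetical helper (`apply_zeropp_workaround` is not a DeepSpeed API) operating on a trimmed copy of the reporter's config. As an assumption beyond the reply, the bf16 branch also turns off `zero_quantized_nontrainable_weights`, the switch named in the issue title.

```python
import copy

# Trimmed copy of the reporter's config: bf16 plus ZeRO++ quantization.
BASE = {
    "fp16": {"enabled": False},
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "zero_quantized_weights": True,
        "zero_hpz_partition_size": 2,
        "zero_quantized_gradients": True,
    },
}

# Quantization switches to disable when staying on bf16 (the last one is an assumption).
QUANT_KEYS = (
    "zero_quantized_weights",
    "zero_quantized_gradients",
    "zero_quantized_nontrainable_weights",
)


def apply_zeropp_workaround(config: dict, use_fp16: bool) -> dict:
    """Hypothetical helper returning a patched copy of a DeepSpeed config.

    use_fp16=True:  switch mixed precision from bf16 to fp16, keep ZeRO++ quantization.
    use_fp16=False: keep bf16, turn ZeRO++ quantization off.
    """
    patched = copy.deepcopy(config)  # leave the caller's dict untouched
    zero = patched.setdefault("zero_optimization", {})
    if use_fp16:
        patched.setdefault("bf16", {})["enabled"] = False
        patched.setdefault("fp16", {})["enabled"] = True
    else:
        for key in QUANT_KEYS:
            zero[key] = False
    return patched


fp16_cfg = apply_zeropp_workaround(BASE, use_fp16=True)
bf16_cfg = apply_zeropp_workaround(BASE, use_fp16=False)
```

The fp16 route keeps the communication savings of ZeRO++; the bf16 route keeps the original precision but gives up quantized communication.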

sam-h-bean commented 3 months ago

@GuanhuaWang But because of other training-stability issues, like this one related to initializing Llama in fp16, training Llama with ZeRO++ is quite troublesome. Should we reopen this issue and look into supporting bf16 in ZeRO++?