pacman100 closed this issue 4 months ago.
Would love to see this fixed for training MoEs on DeepSpeed with quantization + BF16.
Same issue here, training with BF16 + PEFT and ZeRO-3 with ZeRO++:
Stack trace:
Traceback (most recent call last):
File "/nethome/yheng6/Co-Adaptation/src/co_adapt/reward_ensemble/stablelm_zephyr_3b/train.py", line 455, in <module>
main()
File "/nethome/yheng6/Co-Adaptation/src/co_adapt/reward_ensemble/stablelm_zephyr_3b/train.py", line 400, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
return inner_training_loop(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/trainer.py", line 3138, in training_step
loss = self.compute_loss(model, inputs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1081, in compute_loss
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1022, in get_batch_loss_metrics
) = self.concatenated_forward(model, batch)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 985, in concatenated_forward
all_logits = model(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1814, in forward
loss = self.module(*inputs, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/peft_model.py", line 1129, in forward
return self.base_model(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
return self.model.forward(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 1158, in forward
outputs = self.model(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 1026, in forward
layer_outputs = self._gradient_checkpointing_func(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 458, in checkpoint
ret = function(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 759, in forward
self_attn_output, self_attn_weights, present_key_value = self.self_attn(
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/transformers/models/stablelm/modeling_stablelm.py", line 535, in forward
query_states = self.q_proj(hidden_states)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 509, in forward
result = result + lora_B(lora_A(dropout(x))) * scaling
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 109, in zero3_linear_wrap
return LinearFunctionForZeroStage3.apply(input, weight)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
return fwd(*args, **kwargs)
File "/nethome/yheng6/miniconda3/envs/handbook/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py", line 57, in forward
output = input.matmul(weight.t())
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != c10::Half
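The final error can be reproduced in isolation with plain PyTorch. This is a minimal sketch that rests on one assumption: with ZeRO++ weight quantization enabled, the weight reaching LinearFunctionForZeroStage3 has been dequantized to fp16 (Half) while the activations are bf16. That assumption is consistent with the c10::BFloat16 != c10::Half message but is not confirmed from DeepSpeed's source. The helper linear_no_bias is hypothetical and only mirrors the input.matmul(weight.t()) call in the last frame.

```python
import torch

# Activations are bf16 because the DeepSpeed config enables "bf16".
x = torch.randn(2, 8, dtype=torch.bfloat16)

# ASSUMPTION: the dequantized ZeRO++ weight arrives as fp16 (Half).
w_fp16 = torch.randn(4, 8, dtype=torch.float16)


def linear_no_bias(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring `input.matmul(weight.t())` in the last frame."""
    return inp.matmul(weight.t())


mismatch_raised = False
try:
    linear_no_bias(x, w_fp16)  # bf16 @ fp16 -> RuntimeError (dtype mismatch)
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)

# Casting the weight to the activation dtype lets the same matmul run in bf16.
out = linear_no_bias(x, w_fp16.to(x.dtype))
print(out.dtype, tuple(out.shape))
```

On CPU the exact wording of the message can differ between PyTorch versions, but the failure is the same kind of dtype check that fires in the trace.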
ZeRO config (./zero_configs/zero3++.json):
{
"fp16": {
"enabled": false
},
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none",
"nvme_path": "None"
},
"offload_param": {
"device": "none",
"nvme_path": "None"
},
"stage3_gather_16bit_weights_on_model_save": true,
"reduce_bucket_size": "auto",
"zero_quantized_weights": true,
"zero_hpz_partition_size": 2,
"zero_quantized_gradients": true,
"contiguous_gradients": true,
"overlap_comm": true
},
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": "inf"
}
Accelerate config:
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
# deepspeed_multinode_launcher: standard
# offload_optimizer_device: none
# offload_param_device: none
zero3_init_flag: true
# zero3_save_16bit_model: true
# zero_stage: 3
deepspeed_config_file: ./zero_configs/zero3++.json
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
#mixed_precision: bf16
num_machines: 1
#num_processes: 8
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Trained on 2 A5000 GPUs.
Hi, ZeRO++ currently does not support bf16 quantization; I suspect that is the root cause of this issue.
To fix it, you can either:
- use fp16 as the dtype, or
- set "zero_quantized_weights": false and "zero_quantized_gradients": false.
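Both workarounds are small edits to the DeepSpeed JSON. The sketch below applies either one to a config dict (for example one loaded with json.load from ./zero_configs/zero3++.json). The function apply_zeropp_workaround and its mode names are hypothetical; the only keys it touches are the ones named in this thread.

```python
import copy


def apply_zeropp_workaround(ds_config: dict, mode: str) -> dict:
    """Hypothetical helper: return a patched copy of a DeepSpeed config.

    mode="fp16":     switch mixed precision from bf16 to fp16, keep quantization.
    mode="no_quant": keep bf16, disable ZeRO++ weight/gradient quantization.
    """
    cfg = copy.deepcopy(ds_config)  # never mutate the caller's config
    zero = cfg.setdefault("zero_optimization", {})
    if mode == "fp16":
        cfg.setdefault("fp16", {})["enabled"] = True
        cfg.setdefault("bf16", {})["enabled"] = False
    elif mode == "no_quant":
        zero["zero_quantized_weights"] = False
        zero["zero_quantized_gradients"] = False
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return cfg


# Subset of the ZeRO config posted above.
zero3pp = {
    "fp16": {"enabled": False},
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "zero_quantized_weights": True,
        "zero_hpz_partition_size": 2,
        "zero_quantized_gradients": True,
    },
}
bf16_no_quant = apply_zeropp_workaround(zero3pp, "no_quant")
fp16_quant = apply_zeropp_workaround(zero3pp, "fp16")
```

The "fp16" option keeps quantized communication, but the following comment notes that fp16 has its own stability problems for Llama. The "no_quant" option keeps bf16 at the cost of giving up ZeRO++'s reduced communication volume.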
@GuanhuaWang But because of other training-stability issues, such as this one related to initializing Llama in fp16, training Llama with ZeRO++ is quite troublesome. Should we reopen this issue and look into supporting bf16 in ZeRO++?
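One way to picture what bf16 support could involve: quantized weights would be dequantized back into the activation dtype rather than into fp16. This is an assumption, not a description of DeepSpeed internals. The toy round trip below uses symmetric per-tensor int8 quantization; it is not DeepSpeed's quantization kernel, and quantize_int8 and dequantize are hypothetical names.

```python
import torch


def quantize_int8(t: torch.Tensor):
    """Toy symmetric per-tensor int8 quantization (not DeepSpeed's kernel)."""
    scale = t.abs().amax().float().clamp(min=1e-12) / 127.0
    q = torch.clamp(torch.round(t.float() / scale), -127, 127).to(torch.int8)
    return q, scale


def dequantize(q: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Restore into the caller's compute dtype instead of a hard-coded fp16."""
    return (q.float() * scale).to(dtype)


w = torch.randn(4, 8, dtype=torch.bfloat16)   # frozen bf16 weight
q, scale = quantize_int8(w)                   # int8 storage + one fp32 scale
w_hat = dequantize(q, scale, w.dtype)         # back to bf16, not fp16

x = torch.randn(2, 8, dtype=torch.bfloat16)   # bf16 activations
y = x.matmul(w_hat.t())                       # dtypes match, so no RuntimeError
```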
Describe the bug
Using zero_quantized_nontrainable_weights=True with PEFT + DeepSpeed and BF16 mixed-precision training leads to a float != c10::BFloat16 error.

To Reproduce
Steps to reproduce the behavior:

Expected behavior
When using PEFT LoRA with DeepSpeed together with the zero_quantized_nontrainable_weights feature, the non-trainable weights should be quantized, resulting in substantial memory savings. This would enable fine-tuning even larger models or using larger batch sizes.

ds_report output

System info

Launcher context
Accelerate launcher, which internally uses the DeepSpeed launcher.
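The expected-behavior paragraph claims substantial memory savings. The back-of-the-envelope arithmetic below covers weights only and rests on several assumptions:
- about 3B parameters (the traceback path mentions stablelm_zephyr_3b)
- 2 bytes per parameter in bf16
- 1 byte per parameter if frozen weights are stored as int8
- even sharding over 2 GPUs under ZeRO-3

It ignores quantization scales, the hpZ secondary partition, temporarily gathered layers, LoRA weights, activations and optimizer state. weight_memory_gib is a hypothetical helper.

```python
def weight_memory_gib(num_params: float, bytes_per_param: float, num_shards: int = 1) -> float:
    """Approximate per-GPU weight memory in GiB (weights only, evenly sharded)."""
    return num_params * bytes_per_param / num_shards / 2**30


# ASSUMED numbers for illustration only.
params = 3e9
bf16_gib = weight_memory_gib(params, bytes_per_param=2, num_shards=2)
int8_gib = weight_memory_gib(params, bytes_per_param=1, num_shards=2)

print(f"bf16 frozen weights per GPU: {bf16_gib:.2f} GiB")
print(f"int8 frozen weights per GPU: {int8_gib:.2f} GiB")
```

Under these assumptions, quantizing the frozen weights halves their per-GPU footprint (roughly 2.79 GiB down to 1.40 GiB). That headroom is what would allow larger models or batch sizes.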