microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] ZeRO++ is broken: `zero_quantized_weights` fails #4852

Open stas00 opened 8 months ago

stas00 commented 8 months ago

Describe the bug

Adding "zero_quantized_weights": true, leads to a crash:

[35:1]:  warnings.warn(
[35:1]:Traceback (most recent call last):
[35:1]:  File "/data/env/lib/repos/retro-llama/tr042-dawn-llama-2/core/dawn/dawn/training/main.py", line 243, in <module>
[35:1]:    train_logs = trainer.train(
[35:1]:  File "/data/env/lib/repos/retro-llama/tr042-dawn-llama-2/core/dawn/dawn/training/trainer.py", line 1725, in train
[35:1]:    == 0
[35:1]:  File "/data/env/lib/repos/retro-llama/tr042-dawn-llama-2/core/dawn/dawn/training/trainer.py", line 659, in _do_batch
[35:1]:    progress_columns,
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
[35:1]:    return self._call_impl(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
[35:1]:    return forward_call(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
[35:1]:    ret_val = func(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1822, in forward
[35:1]:    loss = self.module(*inputs, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
[35:1]:    return self._call_impl(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
[35:1]:    result = forward_call(*args, **kwargs)
[35:1]:  File "/data/env/lib/repos/retro-llama/tr042-dawn-llama-2/core/dawn/dawn/models/llama/modeling_llama.py", line 1267, in forward
[35:1]:    outputs = self.model(
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
[35:1]:    return self._call_impl(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
[35:1]:    result = forward_call(*args, **kwargs)
[35:1]:  File "/data/env/lib/repos/retro-llama/tr042-dawn-llama-2/core/dawn/dawn/models/llama/modeling_llama.py", line 1131, in forward
[35:1]:    layer_outputs = self._gradient_checkpointing_func(
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/_compile.py", line 24, in inner
[35:1]:    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
[35:1]:    return fn(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
[35:1]:    return fn(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
[35:1]:    return CheckpointFunction.apply(function, preserve, *args)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
[35:1]:    return super().apply(*args, **kwargs)  # type: ignore[misc]
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 230, in forward
[35:1]:    outputs = run_function(*args)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
[35:1]:    return self._call_impl(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
[35:1]:    result = forward_call(*args, **kwargs)
[35:1]:  File "/data/env/lib/repos/retro-llama/tr042-dawn-llama-2/core/dawn/dawn/models/llama/modeling_llama.py", line 826, in forward
[35:1]:    hidden_states, self_attn_weights, present_key_value = self.self_attn(
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
[35:1]:    return self._call_impl(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
[35:1]:    result = forward_call(*args, **kwargs)
[35:1]:  File "/data/env/lib/repos/retro-llama/tr042-dawn-llama-2/core/dawn/dawn/models/llama/modeling_llama.py", line 454, in forward
[35:1]:    query_states = self.q_proj(hidden_states)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
[35:1]:    return self._call_impl(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
[35:1]:    result = forward_call(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
[35:1]:    return F.linear(input, self.weight, self.bias)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/deepspeed/runtime/zero/linear.py", line 109, in zero3_linear_wrap
[35:1]:    return LinearFunctionForZeroStage3.apply(input, weight)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
[35:1]:    return super().apply(*args, **kwargs)  # type: ignore[misc]
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
[35:1]:    return fwd(*args, **kwargs)
[35:1]:  File "/exp/retro-llama/tr042-dawn-llama-2/CONDA_ENV_NAME/lib/python3.9/site-packages/deepspeed/runtime/zero/linear.py", line 57, in forward
[35:1]:    output = input.matmul(weight.t())
[35:1]:RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
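
For context, the failing line boils down to a plain dtype mismatch in `matmul`: per the error, the input (`mat1`) is fp32 while the weight (`mat2`) coming out of the quantized-weights path is fp16 (`c10::Half`). A minimal standalone sketch of the same failure mode, independent of DeepSpeed (the exact error wording varies by device/backend):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(4, 8, device=device)                        # fp32 activations, like mat1 ("float")
w = torch.randn(16, 8, device=device, dtype=torch.float16)  # fp16 weights, like mat2 ("c10::Half")

# Raises a dtype-mismatch RuntimeError; on CUDA the message is exactly
# "expected mat1 and mat2 to have the same dtype, but got: float != c10::Half"
y = x.matmul(w.t())
```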

config:

{
    "bf16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": "auto",
        "contiguous_gradients": true,
        "stage3_gather_16bit_weights_on_model_save": false,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,

        "zero_quantized_weights": true,

        "offload_optimizer": {
            "device": "none"
        },
        "offload_param": {
            "device": "none"
        }
    },
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "steps_per_print": 2000000
}
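
For anyone trying to reproduce outside our trainer, here is a minimal sketch of feeding a config like this straight to `deepspeed.initialize`. The tiny model and the concrete optimizer/batch values are placeholders, and the `"auto"` values above (normally resolved by the HF Trainer integration) are replaced with concrete numbers, since `"auto"` won't work when the dict is passed directly:

```python
# repro.py -- launch with: deepspeed repro.py
import deepspeed
import torch

ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "zero_quantized_weights": True,  # the option that triggers the crash
    },
    # placeholders for the "auto" values in the config above
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    "train_batch_size": 8,
    "train_micro_batch_size_per_gpu": 1,
}

model = torch.nn.Linear(1024, 1024)  # stand-in for the real model
engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

x = torch.randn(1, 1024, dtype=torch.bfloat16, device=engine.device)
y = engine(x)  # forward through the ZeRO-3 engine
```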

ds_report output

DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/env/lib/conda/tr041-dawn-llama/lib/python3.9/site-packages/torch']
torch version .................... 2.1.1+cu121
deepspeed install path ........... ['/data/env/lib/repos/retro-llama/tr043-dawn-llama-3/DeepSpeed/deepspeed']
deepspeed info ................... 0.12.6+48ddf31d, 48ddf31d, HeyangQin/mixz_hpz_fix
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 11.8
deepspeed wheel compiled w. ...... torch 2.1, cuda 12.1
shared memory (/dev/shm) size .... 669.32 GB



@HeyangQin

pacman100 commented 8 months ago

Yes, I am experiencing the same issue.

GuanhuaWang commented 8 months ago

Hi @HeyangQin, I believe we don't support bf16 training with ZeRO++, which is what causes the `RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half`.
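
In the meantime, a user-side guard can fail fast on the combination instead of crashing mid-forward. A sketch (my naming; this is not a DeepSpeed API):

```python
# Hypothetical helper, not part of DeepSpeed: reject the unsupported
# bf16 + zero_quantized_weights combination before training starts.
def assert_zeropp_dtype_ok(ds_config: dict) -> None:
    bf16_on = ds_config.get("bf16", {}).get("enabled", False)
    zero_cfg = ds_config.get("zero_optimization", {})
    if bf16_on and zero_cfg.get("zero_quantized_weights", False):
        raise ValueError(
            "ZeRO++ zero_quantized_weights currently assumes fp16; "
            "disable it or train in fp16 instead of bf16."
        )
```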

ys950902 commented 7 months ago

Hi @GuanhuaWang, thanks for the explanation. Regarding the missing bf16 support in ZeRO++: is it because the quantization kernel behind `zero_quantized_weights` only supports fp16/fp32, or is there a further accuracy issue with the bf16 datatype?

jacklanda commented 4 months ago

Same issue here.

sam-h-bean commented 3 months ago

Same here. Llama doesn't take well to continued training in fp16, so native bf16 support would be great. Also, should we expect to be able to run FP8 training with ZeRO++, @GuanhuaWang?