microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] DeepSpeedEngine did not gather parameters automatically when doing forward() with zero-3 enabled #3570

Open liyonghua0910 opened 1 year ago

liyonghua0910 commented 1 year ago

Describe the bug

I was using DeepSpeed ZeRO-3 in a compression script. When the model was instantiated, I found that every module was injected with a post_init method, which partitioned the parameters into multiple groups and erased the original parameters. However, when the model ran forward propagation, the parameters were not gathered again, leaving every layer with empty weights.

To Reproduce

My ZeRO configuration was simply:

"zero_optimization": {
      "stage": 3,
      "offload_optimizer": {
        "device": "cpu",
        "pin_memory": true
      },
      "offload_param": {
        "device": "cpu",
        "pin_memory": true
      },
    },
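
For anyone trying to reproduce this, the fragment above can be wrapped in a top-level object and serialized as strict JSON. A minimal sketch: only the "zero_optimization" block is taken from the report, the helper name `to_json` is hypothetical, and a real config would need further keys (batch size, precision, and so on) that the report does not show.

```python
import json

# Only the "zero_optimization" block comes from the report above.
# A complete DeepSpeed config needs more keys, which the report does not show.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
}


def to_json(config):
    """Serialize the config dict as strict, indented JSON text."""
    return json.dumps(config, indent=2)


text = to_json(ds_config)
# Round-trip through json.loads to confirm the text parses as strict JSON.
parsed = json.loads(text)
print(parsed["zero_optimization"]["stage"])
```

Building the config as a Python dict and serializing it avoids hand-editing mistakes such as stray commas, and the same dict can be reused in a launch script.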

I have tried chatglm-6b and opt-350m and they led to the same issue.

Expected behavior

The reported error looks something like this:

Traceback (most recent call last):
  File ".../opt-350m/run_glue_no_trainer.py", line 571, in <module>
    main()
  File ".../opt-350m/run_glue_no_trainer.py", line 529, in main
    outputs = model(**batch)
  File ".../.conda/envs/deepspeed/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../deepspeed/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File ".../deepspeed/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1675, in forward
    loss = self.module(*inputs, **kwargs)
  File ".../deepspeed/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File ".../deepspeed/lib/python3.9/site-packages/transformers/models/opt/modeling_opt.py", line 781, in forward
    decoder_outputs = self.decoder(
  File ".../deepspeed/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File ".../deepspeed/lib/python3.9/site-packages/transformers/models/opt/modeling_opt.py", line 631, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File ".../deepspeed/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File ".../deepspeed/lib/python3.9/site-packages/deepspeed/compression/basic_layer.py", line 129, in forward
    out = nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type,
  File ".../deepspeed/lib/python3.9/site-packages/torch/nn/functional.py", line 2192, in embedding
    assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings"
AssertionError: Padding_idx must be within num_embeddings

I printed the padding_idx and the weight:

At forward(), padding_idx: 1, embedding layer weight:
Parameter containing:
tensor([], device='cuda:0', dtype=torch.float16, requires_grad=True)

You can see that the weight is empty.
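
The assertion in the traceback follows directly from the empty weight: a tensor printed as `tensor([])` has size 0 along its first dimension, so `padding_idx < weight.size(0)` becomes `1 < 0`. Below is a minimal plain-Python sketch of that check. The function name `check_padding_idx` is hypothetical and only mirrors the assertion quoted in the traceback; it is not PyTorch code.

```python
def check_padding_idx(padding_idx, num_embeddings):
    """Mirror of the assertion quoted in the traceback (torch/nn/functional.py)."""
    assert padding_idx < num_embeddings, "Padding_idx must be within num_embeddings"
    return True


# Values from the report: padding_idx is 1 and the weight is empty,
# so its first dimension has size 0.
padding_idx = 1
rows_in_empty_weight = 0

try:
    check_padding_idx(padding_idx, rows_in_empty_weight)
    failed = False
    message = ""
except AssertionError as exc:
    failed = True
    message = str(exc)

print(failed, message)
```

With any weight that has more rows than `padding_idx` (for example `check_padding_idx(1, 2)`), the check passes, which suggests the error is a symptom of the weight not being gathered rather than of a bad `padding_idx`.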

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
sparse_attn ............ [NO] ....... [OKAY]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
No CUDA runtime is found, using CUDA_HOME='/mnt/lustre/share/cuda-10.1'
DeepSpeed general environment info:
torch install path ............... ['.../deepspeed/lib/python3.9/site-packages/torch']
torch version .................... 1.13.1+cu117
deepspeed install path ........... ['.../deepspeed/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.9.1, unknown, unknown
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 10.1
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0

senthilps8 commented 1 year ago

@Silhouette2 How did you resolve this issue?

ASR-SCI commented 1 year ago

I also encountered this problem. Looking forward to a solution.

dhar174 commented 10 months ago

Hi @Silhouette2, @ASR-SCI, or @senthilps8. Has anyone been able to resolve this issue? It is very limiting.

tjruwase commented 10 months ago

@dhar174, @ASR-SCI can you please share repro details?

dhar174 commented 10 months ago

@tjruwase I'll do my best, please do ask for any clarifications or additional info.

I'm using DeepSpeed with the Accelerate library, calling my script using "accelerate launch --deepspeed". In the code, I first instantiate the Accelerator object, and then apply the following DeepSpeed settings to the HfDeepSpeedConfig object:

{
            "sparse_attention": {
                "mode": "fixed",
                "block": 16,
                "different_layout_per_head": True,
                "num_local_blocks": 4,
                "num_global_blocks": 1,
                "attention": "bidirectional",
                "horizontal_global_attention": False,
                "num_different_global_patterns": 4,
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 0,
                "loss_scale_window": 1000,
                "initial_scale_power": 16,
                "hysteresis": 2,
                "min_loss_scale": 1,
            },
            "amp": {"enabled": True, "opt_level": "auto"},
            "bf16": {"enabled": False},
            "zero_optimization": {
                "stage": 3,
                "offload_param": {"device": "cpu", "pin_memory": True},
                "overlap_comm": True,
                "contiguous_gradients": True,
                "allgather_bucket_size": 1e7,
                "reduce_bucket_size": 1e7,
                "stage3_prefetch_bucket_size": 1e7,
                "stage3_max_live_parameters": 5e8,
                "stage3_max_reuse_distance": 5e8,
                "stage3_param_persistence_threshold": 1e5,
                "stage3_gather_16bit_weights_on_model_save": True,
            },
            "steps_per_print": 2000,
            "sub_group_size": 5e8,
            "train_batch_size": 1,
            "train_micro_batch_size_per_gpu": 1,
            "wall_clock_breakdown": False,
        }

I then use Accelerator.prepare on the model, before using dispatch like so (replaced filepath with ... here, the real code has the true filepath):

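import torch
from accelerate import dispatch_model, infer_auto_device_map

# Let Accelerate derive a device map under the given memory limits, then dispatch.
model = dispatch_model(
    model,
    device_map=infer_auto_device_map(
        model,
        dtype=torch.float16,
        max_memory={0: "7GiB", "cpu": "48GiB"},
    ),
    offload_dir="/home/.../.../.../offload",
    offload_buffers=True,
)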

I have tried loading different models, and in different ways, to the same effect:

python3.10/site-packages/torch/nn/functional.py, line 2192, in embedding
    assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings"

It may be worth noting that I am generating embeddings with the forward method of a headless model. However, when I use .generate() with a model that has a causal or sequential modeling head, I instead get the error:

File "/home/darf3/buddy/experiments/llama2/lib/python3.10/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D

(Let me know if you want the entire error message for either error.) The only difference in the code between the two errors seems to be whether the model has a head, i.e. whether the forward pass or generate() is used.

Another thing I noted is the following part of the debug feedback from Accelerate (or deepspeed, not sure which):

"[2023-10-03 09:52:24,304] [INFO] [partition_parameters.py:347:exit] finished initializing model - num_params = 245, num_elems = 1.42B"

That just doesn't seem right to me. 245 parameters?

I am also including the YAML config file I'm using for Accelerate config:


compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_config:
  dynamo_backend: INDUCTOR
  dynamo_mode: reduce-overhead
  dynamo_use_dynamic: true
  dynamo_use_fullgraph: true
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

dhar174 commented 10 months ago

Note that while I am using ZeRO-3 init, the same error seems to occur with ZeRO-2. I have also successfully used ZeRO-3 initialization with CPU offloading to load 7B to 13B parameter models on a single-GPU desktop. I can't determine what is different now, other than perhaps 6 months of version changes.

Versions being used in code that produces the error:

accelerate 0.23.0
transformers_version: 4.33.2
Python: 3.10.6
Torch 2.0.1
deepspeed 0.10.3

And some others in the env, possibly less relevant:

huggingface-hub 0.17.2
numpy 1.26.0
nvidia-cublas-cu11 11.10.3.66
nvidia-cuda-cupti-cu11 11.7.101
nvidia-cuda-nvrtc-cu11 11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11 8.5.0.96
nvidia-cufft-cu11 10.9.0.58
nvidia-curand-cu11 10.2.10.91
nvidia-cusolver-cu11 11.4.0.1
nvidia-cusparse-cu11 11.7.4.91
nvidia-nccl-cu11 2.14.3
nvidia-nvtx-cu11 11.7.91
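
For others adding repro details to this thread, a version list like the one above can be collected programmatically. A minimal sketch using the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical, and the package list is simply the main packages named in this comment.

```python
from importlib import metadata

# Distribution names taken from the version list in this comment.
PACKAGES = ["accelerate", "transformers", "torch", "deepspeed"]


def installed_versions(packages):
    """Return {name: version}, using "not installed" for missing distributions."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name} {version}")
```

Running this inside the same environment that produces the error makes it easier to compare a failing setup against one that works.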

tjruwase commented 10 months ago

@dhar174, thanks for sharing these. Here are some clarifications:

Apologies for the misleading "num_params = 245" message from DeepSpeed. It really means 245 tensors (torch.nn.Parameter objects) that have a total of 1.42B elements (model parameters).
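
The distinction between the two numbers in that log line can be illustrated with a small sketch. The parameter names and shapes below are hypothetical and chosen only for illustration; they are not taken from the model in this thread.

```python
import math

# Hypothetical parameter shapes for illustration only.
param_shapes = {
    "embed.weight": (1000, 64),
    "layer.weight": (64, 64),
    "layer.bias": (64,),
}

# Number of tensors: what the log line reports as num_params.
num_params = len(param_shapes)

# Total scalar elements across all tensors: what the log line reports as num_elems.
num_elems = sum(math.prod(shape) for shape in param_shapes.values())

print(f"num_params = {num_params}, num_elems = {num_elems}")
# prints "num_params = 3, num_elems = 68160"
```

On a real PyTorch model the same two quantities would be `len(list(model.parameters()))` and `sum(p.numel() for p in model.parameters())`, although under ZeRO-3 a partitioned parameter may report a local size of 0 from `numel()`.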

"Note that while I am using ZeRO-3 init, the same error seems to occur with ZeRO-2"

ZeRO-Init only works with zero-3. It will fail with any other zero stage. However, I don't see how you are combining zero-init and zero-2.

What would be really helpful is if you shared your full code (e.g., as a gist) and the exact command line. Thanks!

dhar174 commented 10 months ago

You misunderstood: I am not combining ZeRO-2 and zero-init. What I meant was that I have also tried simply using ZeRO-2 as a test to see if the error persisted (it does). My understanding is that zero-init settings are simply ignored under zero-2.

Anyway, let me see if I can get you a gist.

tjruwase commented 10 months ago

@dhar174, thanks for the clarification. Yes, zero-init settings are meant to be ignored for zero-2, but as you well know reality often deviates from intentions :). It would be great to repro the failure with zero-2 and zero-3. I look forward to getting the gist.

dhar174 commented 10 months ago

Here is the gist link, I made it 'secret' but anyone here can feel free to access it.

https://gist.github.com/dhar174/e3106fe0daef4cd0e2498b1a25033f3f

Thank you so much for any help @tjruwase !

tjruwase commented 10 months ago

Thanks, I will try to repro asap.

senthilps8 commented 10 months ago

@dhar174 I think downgrading deepspeed might have fixed this for me. It has been a while so I don't remember exactly. I can tell you the version that worked later today.

dhar174 commented 10 months ago

Any luck with the repro or updates on working version of deepspeed? (I'd really rather use the latest or close to latest version...)