microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0
35.29k stars 4.09k forks

[BUG] Assertion error with HF gradient checkpointing option (param already assigned swap buffer) #4072

Open jhm0104666 opened 1 year ago

jhm0104666 commented 1 year ago

Describe the bug

This issue was split off from https://github.com/microsoft/DeepSpeed/issues/4047#issuecomment-1657086291 at @tjruwase's request.

The activation checkpointing option of DeepSpeed-Infinity is not available with HF models (only Megatron-DeepSpeed supports the full set of DeepSpeed-Infinity options). As a workaround, I passed the "--gradient_checkpointing" option directly to the HF argument parser. However, I ran into "AssertionError: param 773 already assigned swap buffer id 47". The error went away only when I disabled the following ZeRO stage 3 options: stage3_max_live_parameters, stage3_max_reuse_distance, stage3_prefetch_bucket_size.

[Screenshot: 257040400-2df5217f-e729-4d71-b8ec-40720f2cb121]

To Reproduce

Workload

Run script

PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:256 deepspeed --master_port 60000 --num_nodes=1 --num_gpus={NUM GPU TO USE} run_clm.py --deepspeed config.json --model_name_or_path facebook/opt-30b --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --fp16 --per_device_train_batch_size {BATCH SIZE} --learning_rate 2e-5 --num_train_epochs 1 --output_dir result --overwrite_output_dir --save_steps 0 --max_steps 4 --save_strategy "no"

DeepSpeed config (config.json)

{
  "train_micro_batch_size_per_gpu": "auto",
  "fp16": {
    "enabled": true
  },
  "optimizer": {
    "type": "Adam"
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "nvme",
      "nvme_path": "NVME_DEV_TO_BE_FILLED",
      "pin_memory": true,
      "buffer_count": 4,
      "fast_init": false
    },
    "offload_param": {
      "device": "nvme",
      "nvme_path": "NVME_DEV_TO_BE_FILLED",
      "pin_memory": true,
      "buffer_count": 50,
      "buffer_size": 1e9,
      "max_in_cpu": 1e9
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto"
  },
  "aio": {
    "block_size": 1048576,
    "queue_depth": 32,
    "thread_count": 8,
    "single_submit": true,
    "overlap_events": true
  },
  "activation_checkpointing": {},
  "flops_profiler": {
    "enabled": true,
    "profile_step": 1,
    "module_depth": -1,
    "top_modules": 1,
    "detailed": true,
    "output_file": null
  }
}
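
For anyone who wants to check the workaround mentioned in the description, here is a minimal sketch that derives a second config from the one above with the three ZeRO stage 3 options removed. It assumes that "disabled" means deleting the keys so that DeepSpeed falls back to its defaults; the report does not say exactly how the options were turned off. The helper names and the output filename config_workaround.json are hypothetical and are not part of DeepSpeed or HF.

```python
import copy
import json

# The three options the report says had to be disabled to avoid the assertion.
STAGE3_OPTIONS = (
    "stage3_max_live_parameters",
    "stage3_max_reuse_distance",
    "stage3_prefetch_bucket_size",
)


def strip_stage3_options(config: dict) -> dict:
    """Return a copy of a DeepSpeed config dict without the three options.

    Assumption: removing the keys is what "disabling" means here.
    The input dict is left unmodified.
    """
    stripped = copy.deepcopy(config)
    zero = stripped.get("zero_optimization", {})
    for key in STAGE3_OPTIONS:
        zero.pop(key, None)  # ignore keys that are already absent
    return stripped


def write_workaround_config(src: str = "config.json",
                            dst: str = "config_workaround.json") -> None:
    """Read the original config and write the stripped copy (hypothetical paths)."""
    with open(src) as f:
        config = json.load(f)
    with open(dst, "w") as f:
        json.dump(strip_stage3_options(config), f, indent=2)
```

Calling write_workaround_config() next to config.json and passing the resulting file to --deepspeed would give a run that differs from the failing one only in those three keys, which may help narrow down which option triggers the assertion.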

Expected behavior

ds_report output

DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/hamin.jang/work/Profile/LLM/profile_LLM/venv/lib/python3.10/site-packages/torch']
torch version .................... 2.0.1+cu118
deepspeed install path ........... ['/home/hamin.jang/work/Profile/LLM/profile_LLM/venv/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.9.2, unknown, unknown
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 11.8
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.8

Screenshots

Attached above.

System info (please complete the following information):

--extra-index-url https://download.pytorch.org/whl/cu118
accelerate==0.19.0
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==23.1.0
certifi==2022.12.7
charset-normalizer==2.1.1
cmake==3.25.0
datasets==2.13.1
deepspeed==0.9.2
dill==0.3.6
evaluate==0.4.0
filelock==3.9.0
frozenlist==1.3.3
fsspec==2023.5.0
hjson==3.1.0
huggingface-hub==0.14.1
idna==3.4
inquirerpy==0.3.4
Jinja2==3.1.2
lit==15.0.7
MarkupSafe==2.1.2
mpmath==1.2.1
multidict==6.0.4
multiprocess==0.70.14
networkx==3.0
ninja==1.11.1
numpy==1.24.1
packaging==23.1
pandas==2.0.2
pfzy==0.3.4
Pillow==9.3.0
prompt-toolkit==3.0.38
psutil==5.9.5
py-cpuinfo==9.0.0
pyarrow==12.0.1
pydantic==1.10.8
python-dateutil==2.8.2
pytz==2023.3
PyYAML==6.0
regex==2023.5.5
requests==2.28.1
responses==0.18.0
six==1.16.0
sympy==1.11.1
tokenizers==0.13.3
torch==2.0.1+cu118
torchaudio==2.0.2+cu118
torchvision==0.15.2+cu118
tqdm==4.65.0
transformers==4.29.2
triton==2.0.0
typing_extensions==4.4.0
tzdata==2023.3
urllib3==1.26.13
wcwidth==0.2.6
xxhash==3.2.0
yarl==1.9.2

Launcher context

Are you launching your experiment with the deepspeed launcher, MPI, or something else?

deepspeed launcher

Docker context

Are you using a specific docker image that you can share?

I didn't use docker images

tjruwase commented 1 year ago

@jhm0104666, thanks!