microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] Deepspeed repeatedly requests to lock files and gets stuck #5205

Closed pengfei-luo closed 7 months ago

pengfei-luo commented 8 months ago

Describe the bug When training with DeepSpeed using the ZeRO-2 configuration, the program got stuck and did not exit after training finished. I had to use Ctrl+C to end the process.

The traceback showed the process stuck at File "/data2/pfluo/micromamba/envs/torch220/lib/python3.10/site-packages/deepspeed/ops/transformer/inference/triton/matmul_ext.py", line 66, in put, on the statement with FileLock(self.lock_path):

I enabled debug logging for filelock and found that it was stuck in repeated attempts to acquire the lock on the file.
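
To get these messages I turned on Python logging for the filelock package, roughly like this (a minimal sketch; the exact setup that produced the output below may have differed slightly):

import logging

# Emit every filelock acquire/release attempt, in a format similar to the
# lines shown below (timestamp - file[line:N] - LEVEL: message).
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s",
)
logging.getLogger("filelock").setLevel(logging.DEBUG)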

2024-02-27 15:56:22,424 - /data2/pfluo/micromamba/envs/torch220/lib/python3.10/site-packages/filelock/_api.py[line:254] - DEBUG: Attempting to acquire lock 140192009925296 on /data2/pfluo/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock
2024-02-27 15:56:22,425 - /data2/pfluo/micromamba/envs/torch220/lib/python3.10/site-packages/filelock/_api.py[line:254] - DEBUG: Attempting to acquire lock 140233385915056 on /data2/pfluo/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock
2024-02-27 15:56:22,427 - /data2/pfluo/micromamba/envs/torch220/lib/python3.10/site-packages/filelock/_api.py[line:254] - DEBUG: Attempting to acquire lock 140592852660912 on /data2/pfluo/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock
2024-02-27 15:56:22,429 - /data2/pfluo/micromamba/envs/torch220/lib/python3.10/site-packages/filelock/_api.py[line:254] - DEBUG: Attempting to acquire lock 140530144834224 on /data2/pfluo/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock
2024-02-27 15:56:22,431 - /data2/pfluo/micromamba/envs/torch220/lib/python3.10/site-packages/filelock/_api.py[line:266] - DEBUG: Lock 140192009925296 not acquired on /data2/pfluo/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock, waiting 0.05 seconds ...
2024-02-27 15:56:22,433 - /data2/pfluo/micromamba/envs/torch220/lib/python3.10/site-packages/filelock/_api.py[line:266] - DEBUG: Lock 140233385915056 not acquired on /data2/pfluo/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock, waiting 0.05 seconds ...
2024-02-27 15:56:22,434 - /data2/pfluo/micromamba/envs/torch220/lib/python3.10/site-packages/filelock/_api.py[line:266] - DEBUG: Lock 140592852660912 not acquired on /data2/pfluo/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock, waiting 0.05 seconds ...
2024-02-27 15:56:22,436 - /data2/pfluo/micromamba/envs/torch220/lib/python3.10/site-packages/filelock/_api.py[line:266] - DEBUG: Lock 140530144834224 not acquired on /data2/pfluo/.triton/autotune/Fp16Matmul_2d_kernel.pickle.lock, waiting 0.05 seconds .

I modified the file deepspeed/ops/transformer/inference/triton/matmul_ext.py by commenting out lines 66-69, and the program terminated properly.

https://github.com/microsoft/DeepSpeed/blob/aed599b4422b1cdf7397abb05a58c3726523a333/deepspeed/ops/transformer/inference/triton/matmul_ext.py#L66-L69
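
For context, the guarded section is a file-lock-protected pickle write of the autotune table into ~/.triton/autotune. A self-contained sketch of that pattern (paraphrased, not the exact upstream code; the function name is illustrative only) shows what blocks when the lock file lives on NFS:

import os
import pickle
from filelock import FileLock

def save_autotune_table(table, file_path, lock_path):
    # Acquire an exclusive lock on the .lock file, write the table to a
    # temp file, then atomically rename it over the real cache file.
    # On NFS the lock acquisition itself can stall, which is where the
    # process hangs.
    with FileLock(lock_path):
        with open(file_path + ".tmp", "wb") as handle:
            pickle.dump(table, handle)
        os.rename(file_path + ".tmp", file_path)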

I also found that the process got stuck while executing ds_report. I'm not sure whether commenting out those lines has any impact on the training process. In addition, I'm using an NFS filesystem, and I'm not sure whether that affects filelock and in turn triggers this error.
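
A quick way to check whether the cache directory really sits on NFS is to look up its mount in /proc/mounts (a Linux-only sketch; stat -f or df -T give the same answer):

import os

def fs_type(path):
    # Return the filesystem type of the mount that contains `path`
    # by finding the longest matching mount point in /proc/mounts.
    path = os.path.realpath(path)
    best, best_type = "", "unknown"
    with open("/proc/mounts") as mounts:
        for line in mounts:
            _device, mount_point, fstype = line.split()[:3]
            if path == mount_point or path.startswith(mount_point.rstrip("/") + "/"):
                if len(mount_point) > len(best):
                    best, best_type = mount_point, fstype
    return best_type

print(fs_type(os.path.expanduser("~/.triton")))  # 'nfs' or 'nfs4' means the cache is on NFS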

To Reproduce Steps to reproduce the behavior:

  1. Clone the project https://github.com/BAAI-DCAI/Bunny
  2. Finetune the model with the ZeRO-2 configuration:
    {
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto"
    }
    }
  3. After the training is finished, the processes get stuck.

Expected behavior The training procedure should terminate normally.

ds_report output

[2024-02-27 16:50:14,887] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/data2/pfluo/micromamba/envs/torch220/lib/python3.10/site-packages/torch']
torch version .................... 2.2.0+cu118
deepspeed install path ........... ['/data2/pfluo/micromamba/envs/torch220/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.13.2, unknown, unknown
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 11.8
deepspeed wheel compiled w. ...... torch 2.2, cuda 11.8
shared memory (/dev/shm) size .... 503.68 GB


Launcher context I tried both deepspeed and torchrun.

zhxieml commented 7 months ago

I have the same issue. Running rm -rf ~/.triton resolves it for me.

HeyangQin commented 7 months ago

Hello @pengfei-luo @fffffarmer, this happens because DeepSpeed saves the Triton autotune cache when it exits, and since your home directory is on NFS, that save can be slow. The workaround is to set the TRITON_CACHE_DIR environment variable to point to your local hard disk.
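
The variable can be exported in the job script before launching, or set early in the training entry point before DeepSpeed/Triton are imported. Whether the value is read at import time or lazily can depend on the version, so exporting it before launch is the safest; a minimal sketch with a hypothetical local path:

import os

# Hypothetical local (non-NFS) directory; any path on a local disk works.
local_cache = "/local_scratch/triton_cache"
os.makedirs(local_cache, exist_ok=True)
os.environ.setdefault("TRITON_CACHE_DIR", local_cache)

import deepspeed  # imported after the variable is set so the cache path can be picked up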

If the autotune cache is disabled, as you did by commenting it out (or by deleting the cache), it will not impact the quality of training, but it will likely add a 1-3 minute delay when you start or resume training, as Triton will perform autotune from scratch.

We plan to add a warning to explain this to the user when we detect the current cache dir is on a different file system.

Ref: https://github.com/microsoft/DeepSpeed/blame/4520edd61ce1235f87d062aa075cff412ae11a73/deepspeed/ops/transformer/inference/triton/matmul_ext.py#L53C42-L53C58