microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

[BUG] `max_in_cpu` seems to be ignored? #4221

Open andre-bauer opened 1 year ago

andre-bauer commented 1 year ago

Describe the bug
I am evaluating OPT-66B with ZeRO-3 and set param offloading to nvme, which works fine, but I also increased max_in_cpu to 100G. The resulting offload config is printed as:

DeepSpeedZeroOffloadParamConfig(device='nvme', nvme_path=PosixPath('/tmp'), buffer_count=5, buffer_size=100000000, max_in_cpu=100000000000, pin_memory=True)

In float16 I would expect up to 200GB of host memory to be used, but I only see ~16GB of my available 300GB in use. When I set the device to "cpu" instead of "nvme", I get the same behavior for models that exceed 300GB, such as BLOOM-176B. Am I missing something? How can nvme and cpu be used together properly?
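
The expectation comes from treating max_in_cpu as a count of parameter elements, as the printed DeepSpeedZeroOffloadParamConfig above suggests (if it is actually interpreted as bytes, please correct me):

# Arithmetic behind the ~200GB expectation (values taken from the printed config above).
bytes_per_param = 2                         # float16
max_in_cpu = 100_000_000_000                # 1e11 elements
buffer_count, buffer_size = 5, 100_000_000  # NVMe staging buffers, elements each

print(max_in_cpu * bytes_per_param / 1e9)                  # 200.0 GB expected in host memory
print(buffer_count * buffer_size * bytes_per_param / 1e9)  # 1.0 GB of pinned staging buffers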

ds_config

{
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/tmp",
            "pin_memory": true,
            "buffer_count": 5,
            "buffer_size": 1e8,
            "max_in_cpu": 1e11
        }
    },
    "load_from_fp32_weights":false,
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "fp16": {
        "enabled": true
    }
}

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [YES] ...... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
 [WARNING]  On Ampere and higher architectures please use CUDA 11+
spatial_inference ...... [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
 [WARNING]  On Ampere and higher architectures please use CUDA 11+
transformer_inference .. [NO] ....... [NO]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/pyenv-root/versions/3.9.16/lib/python3.9/site-packages/torch']
torch version .................... 2.0.1+cu118
deepspeed install path ........... ['/opt/pyenv-root/versions/3.9.16/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.10.1, unknown, unknown
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 10.1
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.8
shared memory (/dev/shm) size .... 50.00 GB

tjruwase commented 1 year ago

The max_in_cpu flag is activated only for training, not for inference.

andre-bauer commented 1 year ago

That means if I have 300G of RAM for a 301G model, there is no way to offload only the 1G of params to NVMe during inference 🤔? I have to offload the full 301G 😱?

tjruwase commented 1 year ago

@andre-bauer, unfortunately, you are correct. For inference, we currently don't support splitting the offload over dram and nvme. In theory, max_in_cpu could be ported over to the inference code path, but we just have not had the bandwidth to do so :(.
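
In rough pseudocode, the missing piece is a budgeted placement decision at parameter-offload time; the sketch below is purely illustrative and not the actual DeepSpeed code path:

# Illustrative only -- NOT DeepSpeed internals. A max_in_cpu-style split would keep
# parameters in pinned host memory until the element budget is exhausted, then
# spill the remainder to NVMe.
def choose_offload_device(param_numel: int, elements_already_in_cpu: int, max_in_cpu: int) -> str:
    if elements_already_in_cpu + param_numel <= max_in_cpu:
        return "cpu"
    return "nvme"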

An alternative would be to partially offload 300G to DRAM while keeping 1G in HBM. You do this by combining model_persistence_threshold and param_persistence_threshold: set model_persistence_threshold to the size of the model partition you want to pin in HBM (e.g., 1e9) and set param_persistence_threshold to a value greater than the largest layer size (e.g., 2e8 for opt-66b).
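
In ds_config terms that would look roughly like the snippet below (Python dict form; the stage3_-prefixed key spellings are how these thresholds usually appear in the JSON config, so please double-check them against your DeepSpeed version):

# Sketch of the threshold-based alternative: offload params to DRAM ("cpu") and pin a
# slice of the model in HBM via the persistence thresholds. The stage3_* key names are
# assumed spellings of model_persistence_threshold / param_persistence_threshold.
ds_config_sketch = {
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
        "stage3_model_persistence_threshold": 1_000_000_000,  # model partition to keep in HBM
        "stage3_param_persistence_threshold": 200_000_000,    # > largest opt-66b layer
    },
    "fp16": {"enabled": True},
}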

mimiliaogo commented 1 week ago

@tjruwase Hi, I'm a new developer here and I'm interested in contributing to this issue. Is this feature still needed, and do you think it would make a good first issue? Thanks

tjruwase commented 1 week ago

@mimiliaogo, yes this would be a good and useful first issue. Thanks!