kvcache-ai / ktransformers

A Flexible Framework for Experiencing Cutting-edge LLM Inference Optimizations
Apache License 2.0

8-GPU configuration on L40 OOM #76

Closed fengyang95 closed 2 months ago

fengyang95 commented 2 months ago

I tried running DeepSeek-V2 with an 8×L40 (46 GB per GPU) configuration, but I hit a GPU out-of-memory (OOM) error. Why would such a large amount of GPU memory still lead to OOM? My YAML configuration is below:

- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
        generate_device: "cuda:0"
        prefill_device: "cuda:0"

- match:
    name: "^model\\.layers\\.([0-3])\\."
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([4-9]|[1][0-1])\\."
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.([1][2-9])\\."
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
- match:
    name: "^model\\.layers\\.([2][0-7])\\."
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
- match:
    name: "^model\\.layers\\.([2][8-9]|[3][0-5])\\."
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "cuda:4"
      prefill_device: "cuda:4"
- match:
    name: "^model\\.layers\\.([3][6-9]|[4][0-3])\\."
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "cuda:5"
      prefill_device: "cuda:5"
- match:
    name: "^model\\.layers\\.([4][4-9]|[5][0-1])\\."
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "cuda:6"
      prefill_device: "cuda:6"
- match:
    name: "^model\\.layers\\.([5][2-9])\\."
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "cuda:7"
      prefill_device: "cuda:7"

- match:
    name: "^model\\.layers\\.([0-3])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([4-9]|[1][0-1])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([1][2-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([2][0-7])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([2][8-9]|[3][0-5])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:4"
      prefill_device: "cuda:4"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([3][6-9]|[4][0-3])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:5"
      prefill_device: "cuda:5"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([4][4-9]|[5][0-1])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:6"
      prefill_device: "cuda:6"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.([5][2-9])\\.(?!self_attn\\.kv_b_proj).*$"  # regular expression
    class: torch.nn.Linear  # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
    kwargs:
      generate_device: "cuda:7"
      prefill_device: "cuda:7"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"

- match:
    name: "^model\\.layers\\.([0-3])\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([4-9]|[1][0-1])\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.([1][2-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
- match:
    name: "^model\\.layers\\.([2][0-7])\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"

- match:
    name: "^model\\.layers\\.([2][8-9]|[3][0-5])\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:4"
      prefill_device: "cuda:4"
- match:
    name: "^model\\.layers\\.([3][6-9]|[4][0-3])\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:5"
      prefill_device: "cuda:5"
- match:
    name: "^model\\.layers\\.([4][4-9]|[5][0-1])\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:6"
      prefill_device: "cuda:6"
- match:
    name: "^model\\.layers\\.([5][2-9])\\.mlp$"
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function
    kwargs:
      generate_device: "cuda:7"
      prefill_device: "cuda:7"

- match:
    name: "^model\\.layers\\.([0-3])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "cuda:0"
      prefill_op: "KExpertsTorch"
      generate_device: "cuda:0"
      generate_op:  "KExpertsTorch"
      out_device: "cuda:0"
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([4-9]|[1][0-1])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cuda:1"
      generate_op:  "KExpertsTorch"
      out_device: "cuda:1"
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([1][2-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "cuda:2"
      prefill_op: "KExpertsTorch"
      generate_device: "cuda:2"
      generate_op:  "KExpertsTorch"
      out_device: "cuda:2"
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([2][0-7])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "cuda:3"
      prefill_op: "KExpertsTorch"
      generate_device: "cuda:3"
      generate_op:  "KExpertsTorch"
      out_device: "cuda:3"
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([2][8-9]|[3][0-5])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "cuda:4"
      prefill_op: "KExpertsTorch"
      generate_device: "cuda:4"
      generate_op:  "KExpertsTorch"
      out_device: "cuda:4"
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([3][6-9]|[4][0-3])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "cuda:5"
      prefill_op: "KExpertsTorch"
      generate_device: "cuda:5"
      generate_op:  "KExpertsTorch"
      out_device: "cuda:5"
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([4][4-9]|[5][0-1])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "cuda:6"
      prefill_op: "KExpertsTorch"
      generate_device: "cuda:6"
      generate_op:  "KExpertsTorch"
      out_device: "cuda:6"
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([5][2-9])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism
    kwargs:
      prefill_device: "cuda:7"
      prefill_op: "KExpertsTorch"
      generate_device: "cuda:7"
      generate_op:  "KExpertsTorch"
      out_device: "cuda:7"
  recursive: False # don't recursively inject submodules of this module

- match:
    name: "^model\\.layers\\.([0-3])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([4-9]|[1][0-1])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.([1][2-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
- match:
    name: "^model\\.layers\\.([2][0-7])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"

- match:
    name: "^model\\.layers\\.([2][8-9]|[3][0-5])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:4"
      prefill_device: "cuda:4"
- match:
    name: "^model\\.layers\\.([3][6-9]|[4][0-3])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:5"
      prefill_device: "cuda:5"
- match:
    name: "^model\\.layers\\.([4][4-9]|[5][0-1])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:6"
      prefill_device: "cuda:6"
- match:
    name: "^model\\.layers\\.([5][2-9])\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
    kwargs:
      generate_device: "cuda:7"
      prefill_device: "cuda:7"

- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KDeepseekV2Model"
    kwargs:
      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
      transfer_map: 
        4: "cuda:1"
        12: "cuda:2"
        20: "cuda:3"
        28: "cuda:4"
        36: "cuda:5"
        44: "cuda:6"
        52: "cuda:7"

- match:
    name: "^model\\.layers\\.([0-3])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([4-9]|[1][0-1])\\.)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "(^model\\.layers\\.([1][2-9])\\.)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
- match:
    name: "(^model\\.layers\\.([2][0-7])\\.)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"

- match:
    name: "(^model\\.layers\\.([2][8-9]|[3][0-5])\\.)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:4"
      prefill_device: "cuda:4"
- match:
    name: "(^model\\.layers\\.([3][6-9]|[4][0-3])\\.)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:5"
      prefill_device: "cuda:5"
- match:
    name: "(^model\\.layers\\.([4][4-9]|[5][0-1])\\.)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:6"
      prefill_device: "cuda:6"
- match:
    name: "(^model\\.layers\\.([5][2-9])\\.)|(^model.norm)|(^lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:7"
      prefill_device: "cuda:7"
Azure-Tang commented 2 months ago

Ok, if you want to put all modules on the GPU, the required VRAM per layer is approximately 5.2 GB, so 9 layers on one GPU may cause OOM. You can offload one layer's experts to the CPU.
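
As a rough back-of-the-envelope check (taking the ~5.2 GB/layer estimate at face value against the ~46 GB per card mentioned above):

    9 layers × 5.2 GB ≈ 46.8 GB   (already above the per-card budget)
    8 layers × 5.2 GB ≈ 41.6 GB   (only ~4 GB left for the KV cache, activations, and the CUDA context)

So even eight full layers per GPU sit right at the edge.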

fengyang95 commented 2 months ago

@Azure-Tang But I only put 8 layers on each GPU; why is it still OOM? Do I need to reserve some headroom?

Azure-Tang commented 2 months ago

The 5.2 GB figure is an approximation, and the actual value may be slightly higher, so 8 layers per GPU is likely the upper limit. You can offload one layer's experts to the CPU for each GPU; the CPU then handles 8 layers' expert modules in total (one per GPU), while the remaining parts of those layers are computed on the GPUs.

fengyang95 commented 2 months ago

Do you mean to allocate 7 layers to each GPU and offload 1 layer to the CPU?

Azure-Tang commented 2 months ago

Not a whole layer to the CPU; offloading only the experts module of one decoder layer on each 8-layer GPU should be enough.

fengyang95 commented 2 months ago

I didn't quite understand your point. Could you please demonstrate it with a configuration?

Azure-Tang commented 2 months ago

For example, say you have 8 layers on each GPU and that causes OOM, and the parameter distribution of those 8 layers looks like this:

[image: parameter breakdown of the 8 layers assigned to one GPU]

Then you offload one layer's experts module to the CPU (which we have done in our example YAMLs), so that layer looks like this:

[image: one decoder layer with its experts module moved to the CPU]

and keep the other 7 layers entirely on the GPU, like this:

[image: the remaining 7 layers kept fully on the GPU]

This moves some parameters to your RAM. If it still OOMs, you can do the same for another layer.
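
Concretely, here is a sketch of what that could look like for the cuda:1 group (layers 4-11) in your YAML above. It is only an illustration following the pattern of our example YAMLs and the KExpertsCPU backend used there; the choice of layer 11 is arbitrary. The experts rule for that range is split in two, so layers 4-10 keep their experts on the GPU while layer 11's experts are computed on the CPU and only the output is placed back on cuda:1:

- match:
    name: "^model\\.layers\\.([4-9]|[1][0])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts     # experts of layers 4-10 stay on GPU
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cuda:1"
      generate_op: "KExpertsTorch"
      out_device: "cuda:1"
  recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([1][1])\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExperts     # layer 11's expert weights live in CPU RAM
    kwargs:
      prefill_device: "cuda:1"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda:1"      # expert output is returned to cuda:1
  recursive: False # don't recursively inject submodules of this module

Repeat the same split for each of the other 8-layer GPUs (cuda:2 through cuda:7).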

Azure-Tang commented 2 months ago

Hi, I'll be closing this issue as there hasn't been any further response~