microsoft / DeepSpeedExamples

Example models using DeepSpeed
Apache License 2.0

Why does ZeRO-2 use more CUDA memory than ZeRO-1? #123

Open dancingpipi opened 3 years ago

dancingpipi commented 3 years ago

Following the bing_bert tutorial, my deepspeed_config is:

```json
{
  "train_batch_size": 4096,
  "train_micro_batch_size_per_gpu": 32,
  "steps_per_print": 1000,
  "prescale_gradients": false,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 6e-3,
      "betas": [
        0.9,
        0.99
      ],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },

  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "overlap_comm": false,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": true,
    "grad_hooks": true,
    "round_robin_gradients": false
  },

  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 1e-8,
      "warmup_max_lr": 6e-3
    }
  },
  "gradient_clipping": 1.0,

  "wall_clock_breakdown": false,

  "fp16": {
    "enabled": true,
    "loss_scale": 0
  },
  "sparse_attention": {
    "mode": "fixed",
    "block": 16,
    "different_layout_per_head": true,
    "num_local_blocks": 4,
    "num_global_blocks": 1,
    "attention": "bidirectional",
    "horizontal_global_attention": false,
    "num_different_global_patterns": 4
  }
}
```
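For reference, DeepSpeed documents the batch-size relation `train_batch_size = train_micro_batch_size_per_gpu × gradient accumulation steps × number of GPUs`. The issue does not say how many GPUs were used, so this minimal sketch (the helper `derive_grad_accum_steps` is hypothetical, not a DeepSpeed API) derives the accumulation steps this config implies for a few assumed GPU counts.

```python
def derive_grad_accum_steps(train_batch_size: int, micro_batch_per_gpu: int, world_size: int) -> int:
    """Hypothetical helper: solve
    train_batch_size == micro_batch_per_gpu * grad_accum_steps * world_size
    for grad_accum_steps, rejecting combinations that do not divide evenly."""
    per_step = micro_batch_per_gpu * world_size
    if train_batch_size % per_step != 0:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"micro_batch_per_gpu={micro_batch_per_gpu} x world_size={world_size}"
        )
    return train_batch_size // per_step


# Batch sizes taken from the config above; the GPU counts are assumptions.
for gpus in (1, 8, 16, 64):
    steps = derive_grad_accum_steps(4096, 32, gpus)
    print(f"{gpus} GPU(s): {steps} gradient accumulation steps")
```

Because ZeRO stages 1 and 2 do not partition activations, the per-GPU micro-batch of 32 should set the same activation footprint in both runs; the stage change affects only model-state partitioning and ZeRO's own buffers.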

The CUDA memory usage is 8900 MB per GPU with stage 1 and 9600 MB per GPU with stage 2.

ZeRO-2 also trains much slower than ZeRO-1.

Any help would be appreciated~

tjruwase commented 2 years ago

@dancingpipi, apologies for the delayed response. Hope the answers below are still helpful.

  1. ZeRO is designed to reduce the memory overheads of very large models with billions of parameters. ZeRO achieves this through extra communication and fixed-size computation buffers, whose costs are amortized by the large amount of computation in massive models. However, for small models like BERT (~300M parameters), these ZeRO features cause slowdowns and excess memory consumption relative to the baseline case.

  2. Also, higher stages of ZeRO can fit larger models than lower stages, at the price of extra communication and computation-buffer costs. Thus, for large model training, it is best to use the lowest ZeRO stage that fits the model (and batch size).
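To put rough numbers on both points, here is a minimal sketch of the per-GPU model-state accounting from the ZeRO paper for mixed-precision Adam: 2 bytes of fp16 parameters, 2 bytes of fp16 gradients and K = 12 bytes of optimizer state per parameter, with stage 1 partitioning the optimizer states, stage 2 also the gradients, and stage 3 also the parameters. The helper `lowest_fitting_stage` is hypothetical, and the GPU counts, memory budgets and per-stage overhead estimates are assumptions. The model ignores activations and the exact buffers DeepSpeed allocates unless you pass them in as overhead.

```python
from typing import Dict, Optional

# Optimizer-state bytes per parameter for mixed-precision Adam, following the
# ZeRO paper's accounting: fp32 master weights, momentum and variance (3 x 4 bytes).
ADAM_K = 12


def model_state_bytes(stage: int, num_params: float, num_gpus: int, k: int = ADAM_K) -> float:
    """Per-GPU bytes for fp16 parameters (2 B), fp16 gradients (2 B) and optimizer
    states (k B per parameter), partitioned across num_gpus according to the stage."""
    psi, n = num_params, num_gpus
    if stage == 0:  # ZeRO disabled: everything replicated
        return (2 + 2 + k) * psi
    if stage == 1:  # optimizer states partitioned
        return 2 * psi + 2 * psi + k * psi / n
    if stage == 2:  # optimizer states and gradients partitioned
        return 2 * psi + (2 + k) * psi / n
    if stage == 3:  # parameters partitioned as well
        return (2 + 2 + k) * psi / n
    raise ValueError(f"unknown ZeRO stage: {stage}")


def lowest_fitting_stage(
    num_params: float,
    num_gpus: int,
    budget_bytes: float,
    overhead_bytes: Optional[Dict[int, float]] = None,
) -> Optional[int]:
    """Hypothetical helper: return the lowest stage whose model states plus an
    estimated per-stage overhead (activations, communication buffers) fit in
    budget_bytes, or None if even stage 3 does not fit."""
    overhead_bytes = overhead_bytes or {}
    for stage in range(4):
        total = model_state_bytes(stage, num_params, num_gpus) + overhead_bytes.get(stage, 0.0)
        if total <= budget_bytes:
            return stage
    return None


if __name__ == "__main__":
    # A ~300M-parameter BERT on an assumed 8 GPUs.
    s1 = model_state_bytes(1, 3e8, 8)
    s2 = model_state_bytes(2, 3e8, 8)
    print(f"stage 1: {s1 / 1e9:.3f} GB, stage 2: {s2 / 1e9:.3f} GB, saved: {(s1 - s2) / 1e9:.3f} GB")
    # For scale: one bucket of reduce_bucket_size = 5e8 fp16 elements (2 bytes each).
    print(f"one 5e8-element fp16 bucket: {5e8 * 2 / 1e9:.3f} GB")
    # An assumed 10B-parameter model with a 32 GB per-GPU budget, on 8 vs 64 GPUs.
    print(lowest_fitting_stage(1e10, 8, 32e9), lowest_fitting_stage(1e10, 64, 32e9))
```

Under these assumptions, stage 2 saves only 2Ψ(1 − 1/N_d) ≈ 0.53 GB of model states over stage 1 for a ~300M-parameter model on 8 GPUs, roughly half the size of a single 5e8-element fp16 bucket (1 GB). So for a BERT-sized model, fixed-size buffers can easily outweigh the savings. For the 10B-parameter case, the helper picks stage 3 on 8 GPUs but stage 2 on 64 GPUs, because more data-parallel ranks shrink each partition and a lower stage becomes sufficient.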