philschmid / deep-learning-pytorch-huggingface


OOM when finetuning FLANT5-xxl #6

Closed AndrewZhe closed 1 year ago

AndrewZhe commented 1 year ago

I am using 4x A40 48GB GPUs to fine-tune FLAN-T5-XXL, following your blog "https://www.philschmid.de/fine-tune-flan-t5-deepspeed#2-fine-tune-flan-t5-xxl-using-deepspeed". However, I run into OOM with batch_size = 1 per GPU.

I use fp16 with offload turned off. If I turn offload on, the code runs and each GPU uses about 22 GB of VRAM.

Do you have any suggestions?
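
To clarify what I mean by offload: it is the standard ZeRO-3 CPU offload setting in the DeepSpeed config. A trimmed sketch of the relevant block follows (written as a Python dict, since the Trainer also accepts the config as a dict instead of a JSON path; this is not the exact file from the blog):

# Trimmed sketch (not the exact blog file): the offload variant only adds the two
# "offload_*" entries to the ZeRO-3 block; everything else stays the same.
zero_optimization_offload = {
    "stage": 3,
    "overlap_comm": True,
    "contiguous_gradients": True,
    "offload_optimizer": {"device": "cpu", "pin_memory": True},  # optimizer states in host RAM
    "offload_param": {"device": "cpu", "pin_memory": True},      # parameters in host RAM
    "stage3_gather_16bit_weights_on_model_save": False,
}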

philschmid commented 1 year ago

Which DeepSpeed config did you use? Also, T5 does not work with fp16.

AndrewZhe commented 1 year ago

Sorry, it is bf16.

I use ds_flan_t5_z3_config_bf16.json

{
  "bf16": {
    "enabled": "auto"
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": false
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
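
The "auto" fields are resolved by the Hugging Face Trainer from the training arguments at launch time. Roughly how it is wired up on my side (a sketch with assumed script and argument values, following the blog):

# Sketch: the Trainer fills the "auto" entries (lr, betas, batch sizes, bf16, ...)
# from these arguments when DeepSpeed is enabled. Values here are illustrative.
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="flan-t5-xxl-ft",
    per_device_train_batch_size=1,   # batch_size = 1 per GPU, which still OOMs on 4x A40
    learning_rate=1e-4,
    bf16=True,                       # matches the bf16 block in the config above
    deepspeed="ds_flan_t5_z3_config_bf16.json",
)
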
philschmid commented 1 year ago

Yes, it seems that 4x A40 48GB is not enough to fine-tune the XXL model without offloading. I only tested 8x A100 40GB without offloading. It seems the activations are so big that it does not fit.
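
A rough back-of-the-envelope check (assuming ~11B parameters for FLAN-T5-XXL and the usual bf16 + Adam state sizes under ZeRO-3; activations come on top and are not sharded):

# Assumptions: ~11B params; per parameter we keep bf16 weights (2 B), bf16 grads (2 B),
# fp32 master weights (4 B) and two fp32 Adam states (4 B + 4 B) = 16 B, sharded by ZeRO-3.
params = 11e9
bytes_per_param = 2 + 2 + 4 + 4 + 4

for num_gpus in (4, 8):
    per_gpu_gb = params * bytes_per_param / num_gpus / 1024**3
    print(f"{num_gpus} GPUs: ~{per_gpu_gb:.0f} GB per GPU for model + optimizer states")
# 4 GPUs: ~41 GB per GPU -> activations push a 48 GB A40 over the limit
# 8 GPUs: ~20 GB per GPU -> leaves headroom for activations on a 40 GB A100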

AndrewZhe commented 1 year ago

Thanks! I will try to find another 4 GPUs and give it a try.

It is really attractive that 8x A100 40GB can run batch_size = 8 per GPU without offloading.