Mikubill / naifu

Train generative models with pytorch lightning
MIT License

When the strategy is deepspeed, ZeRO raises a KeyError and training just crashes #36

Open X-MAXXIX opened 1 week ago

X-MAXXIX commented 1 week ago

When we perform multi-machine, multi-GPU training, we get an out-of-memory error on the GPUs. After troubleshooting, we identified this issue. If you can fix it, we would be very grateful.

Mikubill commented 1 week ago

Could you provide more information about your configuration, hardware, and environment? As a common solution, reducing the batch size or the accumulation step might help.
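
(For readers following along, here is a minimal sketch, not naifu's trainer code, of the gradient-accumulation pattern these settings control, written against Lightning Fabric; the toy model, optimizer, and `accumulate_steps` value are stand-ins for the real config keys discussed below.)

```python
# Minimal sketch with Lightning Fabric; not the naifu training loop,
# just the accumulation pattern that batch_size / accumulate_grad_batches control.
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.fabric import Fabric

accumulate_steps = 8  # effective batch = micro_batch * accumulate_steps

fabric = Fabric(accelerator="cpu", devices=1)  # the real run uses gpu + bf16
fabric.launch()

model = nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model, optimizer = fabric.setup(model, optimizer)

for step in range(32):
    x, y = torch.randn(1, 16), torch.randn(1, 1)  # micro-batch of 1
    is_accumulating = (step + 1) % accumulate_steps != 0
    # Skip gradient sync on accumulation steps to avoid extra communication.
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        loss = F.mse_loss(model(x), y)
        fabric.backward(loss / accumulate_steps)
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()
```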

Yidhar commented 1 week ago

> Could you provide more information about your configuration, hardware, and environment? As a common solution, reducing the batch size or the accumulation step might help.

We use H100*8 and use xformers to train the two text encoders simultaneously. The batch size is 1 and the accumulation step is 8, but the results are generally the same: initially only 58 GB of memory is required, but it soon grows to more than 80 GB, resulting in out-of-memory. The KeyError occurs when I try to optimize with DeepSpeed ZeRO:

```
[rank4]: KeyError: Parameter containing:
[rank4]: tensor([[-0.0128, -0.0060, -0.0098, ...,  0.0004, -0.0190, -0.0070],
[rank4]:         [ 0.0079,  0.0048,  0.0157, ...,  0.0004, -0.0063, -0.0040],
[rank4]:         [-0.0032,  0.0089, -0.0058, ..., -0.0145,  0.0044,  0.0134],
[rank4]:         ...,
[rank4]:         [-0.0035,  0.0126, -0.0035, ...,  0.0159, -0.0176,  0.0225],
[rank4]:         [-0.0124, -0.0085,  0.0201, ...,  0.0202,  0.0069, -0.0044],
[rank4]:         [-0.0029, -0.0132, -0.0124, ...,  0.0118, -0.0026, -0.0027]],
[rank4]:        device='cuda:4', requires_grad=True)
```

Yidhar commented 1 week ago

The relevant Python package versions are as follows:

torch 2.4.1, xformers 0.0.28.post1, lightning 2.4.0, deepspeed 0.15.2, diffusers 0.30.3

liesened commented 1 week ago

Can you provide more information on how to reproduce this issue?

Yidhar commented 1 week ago

deepspeed config:

```json
{
  "train_batch_size": 8,
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": true,
    "allgather_bucket_size": 1e9,
    "reduce_scatter": true,
    "reduce_bucket_size": 1e9,
    "overlap_comm": true,
    "contiguous_gradients": true
  },
  "bf16": {
    "enabled": true
  }
}
```

train yaml:

```yaml
name: test-run
target: modules.train_sdxl_hezi_rantag.setup

trainer:
  model_path: /data56/noob_hercules3/fp16/checkpoint-e1_s54544.safetensors
  batch_size: 1
  resolution: 1024
  world_size: 8
  seed: 114514
  wandb_id: "sdxl"
  use_xformers: true
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  save_format: safetensors
  checkpoint_dir: "/data/sdxl/noob_hercules4/checkpoint"
  checkpoint_freq: 1
  checkpoint_steps: 1000
  save_weights_only: true
  max_epochs: 60
  max_steps: -1
  gradient_checkpointing: true

advanced:
  vae_encode_batch_size: -1 # same as batch_size
  train_text_encoder_1: true
  train_text_encoder_2: true
  text_encoder_1_lr: 3e-6
  text_encoder_2_lr: 3e-6
  offset_noise: true
  offset_noise_val: 0.0375
  min_snr: true
  min_snr_val: 5
  timestep_start: 0
  timestep_end: 1000
  v_parameterization: false
  zero_terminal_snr: false
  do_edm_style_training: false

lightning:
  accelerator: gpu
  devices: 8
  strategy: deepspeed
  precision: bf16-mixed
  num_nodes: 1
dataset:
  name: data_loader.arrow_load_stream.TextImageArrowStream
  target_area: 1_048_576 # 1024*1024
  index_file: "dataset/porcelain/jsons/porcelain_mt.json"
  multireso: true
  num_workers: 1
  min_size: 512
  max_size: 2048
  img_path: "/root/niji-anime-1"
  random_flip: false
  # process_batch_fn: "data.processors.shuffle_prompts_sdstyle"
  max_token_length: 225 # [75, 150, 225]

optimizer:
  name: torch.optim.AdamW
  params:
    lr: 1e-5
    weight_decay: 1e-2

scheduler:
  name: transformers.get_constant_schedule_with_warmup
  params:
    num_warmup_steps: 0
    last_epoch: -1
```
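
(Side note for readers: the optimizer and scheduler entries above appear to be dotted import paths plus keyword arguments. A minimal sketch of how they would resolve, assuming that is indeed how naifu instantiates them:)

```python
# Sketch only; assumes the config's dotted names map directly to these APIs.
import torch
from transformers import get_constant_schedule_with_warmup

params = [torch.nn.Parameter(torch.zeros(4))]  # stand-in for the model parameters
optimizer = torch.optim.AdamW(params, lr=1e-5, weight_decay=1e-2)
scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=0, last_epoch=-1)
```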

error:

```
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/naifu/trainer.py", line 58, in <module>
[rank0]:   File "/data/naifu/trainer.py", line 54, in main
[rank0]:     Trainer(fabric, config).train_loop()
[rank0]:   File "/data/naifu/common/trainer.py", line 23, in __init__
[rank0]:     model, dataset, dataloader, optimizer, scheduler = model_cls(fabric, config)
[rank0]:   File "/data/naifu/modules/train_sdxl_hezi_rantag.py", line 102, in setup
[rank0]:     model.model, optimizer = fabric.setup(model.model, optimizer)
[rank0]:   File "/datavenv/envs/sdxlroot/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 245, in setup
[rank0]:     module, optimizers = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
[rank0]:   File "/datavenv/envs/sdxlroot/lib/python3.10/site-packages/lightning/fabric/strategies/deepspeed.py", line 331, in setup_module_and_optimizers
[rank0]:     self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0])
[rank0]:   File "/datavenv/envs/sdxlroot/lib/python3.10/site-packages/lightning/fabric/strategies/deepspeed.py", line 607, in _initialize_engine
[rank0]:     deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
[rank0]:   File "/datavenv/envs/sdxlroot/lib/python3.10/site-packages/deepspeed/__init__.py", line 193, in initialize
[rank0]:     engine = DeepSpeedEngine(args=args,
[rank0]:   File "/datavenv/envs/sdxlroot/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 313, in __init__
[rank0]:     self._configure_optimizer(optimizer, model_parameters)
[rank0]:   File "/datavenv/envs/sdxlroot/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1302, in _configure_optimizer
[rank0]:     self.optimizer = self._configure_zero_optimizer(basic_optimizer)
[rank0]:   File "/datavenv/envs/sdxlroot/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1560, in _configure_zero_optimizer
[rank0]:     optimizer = DeepSpeedZeroOptimizer(
[rank0]:   File "/datavenv/envs/sdxlroot/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 553, in __init__
[rank0]:     self._param_slice_mappings = self._create_param_mapping()
[rank0]:   File "/datavenv/envs/sdxlroot/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 574, in _create_param_mapping
[rank0]:     lp_name = self.param_names[lp]
[rank0]: KeyError: Parameter containing:
[rank0]: tensor([[-0.0051,  0.0369,  0.0221, ...,  0.0159,  0.0065, -0.0221],
[rank0]:         [ 0.0156,  0.0261, -0.0133, ..., -0.0037,  0.0043,  0.0036],
[rank0]:         [-0.0154,  0.0004, -0.0156, ..., -0.0206,  0.0015, -0.0057],
[rank0]:         ...,
[rank0]:         [ 0.0154,  0.0156, -0.0005, ...,  0.0102, -0.0199,  0.0145],
[rank0]:         [ 0.0037, -0.0017,  0.0007, ..., -0.0071, -0.0069,  0.0055],
[rank0]:         [-0.0009,  0.0109,  0.0019, ...,  0.0018,  0.0051,  0.0076]],
[rank0]:        device='cuda:0', requires_grad=True)
```

This is the error I get when using deepspeed. When I don't use deepspeed and leave the strategy at the default, the same training settings use 58 GB of memory on the first step, 78 GB on the second step, and hit out-of-memory on the third step.
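
(The last frame of the trace is informative: `DeepSpeedZeroOptimizer._create_param_mapping` looks every optimizer parameter up in a name map built from the module handed to `deepspeed.initialize`. Below is a minimal sketch, an assumption about the failure mode rather than naifu's code, of how a KeyError like this arises when the optimizer holds parameters, such as the text encoders', that are not registered on the wrapped module.)

```python
# Toy reconstruction of the KeyError; unet/text_encoder are hypothetical
# stand-ins, not naifu's actual models.
import torch
import torch.nn as nn

unet = nn.Linear(8, 8)          # stands in for the module passed to fabric.setup()
text_encoder = nn.Linear(8, 8)  # stands in for a text encoder trained separately

# DeepSpeed builds (roughly) this mapping from the wrapped module only.
param_names = {p: n for n, p in unet.named_parameters()}

# ...but the optimizer also carries the text encoder's parameters.
optimizer = torch.optim.AdamW(
    list(unet.parameters()) + list(text_encoder.parameters()), lr=1e-5
)

try:
    for group in optimizer.param_groups:
        for lp in group["params"]:
            _ = param_names[lp]  # mirrors stage_1_and_2.py::_create_param_mapping
    print("every optimizer parameter is owned by the wrapped module")
except KeyError:
    print("KeyError: the optimizer holds a parameter the wrapped module does not own")
```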

Mikubill commented 1 week ago

OK, thanks for the report. I have made some modifications to the deepspeed (w/ sdxl) integration and tested them using:

```yaml
lightning:
  accelerator: gpu
  devices: -1
  strategy: common.deepspeed._sdxl_strategy
  precision: bf16
```

For the full config, please refer to config/train_sdxl_deepspeed.yaml. Feel free to give it a try.
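
(For anyone reading along: `common.deepspeed._sdxl_strategy` is a strategy object defined inside this repo. A rough guess at its shape, assuming it is a preconfigured Lightning Fabric `DeepSpeedStrategy`; the actual settings live in the repo and may differ:)

```python
# Hypothetical sketch; check common/deepspeed.py and
# config/train_sdxl_deepspeed.yaml in the repo for the real definition.
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

_sdxl_strategy = DeepSpeedStrategy(
    stage=2,                  # assumption: the ZeRO stage actually used may differ
    offload_optimizer=False,  # assumption
)

fabric = Fabric(
    accelerator="gpu",
    devices=-1,
    strategy=_sdxl_strategy,
    precision="bf16-mixed",   # the yaml above specifies bf16
)
```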