XLabs-AI / x-flux

Apache License 2.0

How to use zero3 to train the model? #53

Status: Open. lgs00 opened this issue 2 months ago.

lgs00 commented 2 months ago

How do I use ZeRO-3 to train the model? ZeRO-3 can reduce CUDA memory consumption, and I prepare the model (tran) with Accelerate like this:

tran, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    tran, optimizer, train_dataloader, lr_scheduler
)
print(f'*********{total_params:,} total tran parameters on device {torch.cuda.current_device()}.')

However, when I use ZeRO-3, the parameter count computed after accelerator.prepare is 0, and the saved model is only 6 MB. What is the problem? My ZeRO-3 config is below:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
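A likely cause of the zero count, assuming standard DeepSpeed ZeRO-3 behavior: after accelerator.prepare, each parameter is partitioned across the ranks and the local tensor is replaced by an empty placeholder, so summing p.numel() gives 0. DeepSpeed records the full, unpartitioned size of each parameter in its ds_numel attribute. The helper below (count_parameters is a hypothetical name, not part of x-flux or Accelerate) is a minimal sketch that prefers ds_numel when it is present and falls back to numel() otherwise.

```python
def count_parameters(model, trainable_only=False):
    """Return the full parameter count, even when ZeRO-3 has partitioned the model.

    Hypothetical helper: relies on DeepSpeed storing the original element
    count of each partitioned parameter in `p.ds_numel`.
    """
    total = 0
    for p in model.parameters():
        # Optionally skip frozen parameters (e.g. a frozen base model).
        if trainable_only and not p.requires_grad:
            continue
        # Under ZeRO-3 the local tensor is an empty placeholder, so p.numel()
        # is 0; ds_numel holds the real size. Without ZeRO-3, use numel().
        total += p.ds_numel if hasattr(p, "ds_numel") else p.numel()
    return total
```

Used in place of the original total_params computation, for example print(f"{count_parameters(tran):,} total tran parameters"), it should report the same number with and without ZeRO-3.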
sudanl commented 1 month ago

Hi! Have you solved this problem?
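A plausible cause of the 6 MB checkpoint, under the same ZeRO-3 assumption: calling state_dict() on the partitioned model returns only the local placeholders, not the full weights. A minimal sketch, assuming Accelerate's accelerator.get_state_dict(), which gathers ZeRO-3 weights only when DeepSpeed's stage3_gather_16bit_weights_on_model_save option is enabled (in the Accelerate YAML this should correspond to zero3_save_16bit_model: true under deepspeed_config, which the config above does not set). save_full_state_dict is a hypothetical helper name.

```python
def save_full_state_dict(accelerator, model, path):
    """Gather a ZeRO-3 partitioned model's weights and save them once.

    Hypothetical helper built on Accelerate's Accelerator API.
    """
    # Make sure every rank has finished its training step before gathering.
    accelerator.wait_for_everyone()
    # Gathering is a collective operation, so every rank must call this.
    # Under ZeRO-3 it raises an error unless
    # stage3_gather_16bit_weights_on_model_save is enabled.
    state_dict = accelerator.get_state_dict(model)
    # Only the main process writes the file, avoiding duplicate writes.
    if accelerator.is_main_process:
        accelerator.save(state_dict, path)
```

Because get_state_dict() must run on all ranks, the call should not be wrapped in an is_main_process check; only the final write is.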