tianweiy / DMD2

(NeurIPS 2024 Oral 🔥) Improved Distribution Matching Distillation for Fast Image Synthesis

Training a 4-step SDXL model raises AssertionError: Invalid: `_post_backward_hook_state` #41

Open joelulu opened 3 months ago

joelulu commented 3 months ago

Thanks for the excellent work! When I try to train a 4-step SDXL model (2 nodes, 16 GPUs), I get an error:

rank2: Traceback (most recent call last):
rank2:   File "/mnt/nas/gaohl/project/DMD2-main/main/train_sd.py", line 739, in
rank2:   File "/mnt/nas/gaohl/project/DMD2-main/main/train_sd.py", line 633, in train
rank2:   File "/mnt/nas/gaohl/project/DMD2-main/main/train_sd.py", line 390, in train_one_step
rank2:   File "/usr/local/lib/python3.10/site-packages/accelerate/accelerator.py", line 2159, in backward
rank2:   File "/usr/local/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
rank2:   File "/usr/local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
rank2:   File "/usr/local/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
rank2:     return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
rank2:   File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
rank2:     return func(*args, **kwargs)
rank2:   File "/usr/local/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1099, in _post_backward_final_callback
rank2:   File "/usr/local/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1168, in _finalize_params
rank2:   File "/usr/local/lib/python3.10/site-packages/torch/distributed/utils.py", line 146, in _p_assert
rank2:     raise AssertionError(s)
rank2: AssertionError: Invalid: _post_backward_hook_state: (<torch.autograd.graph.register_multi_grad_hook..Handle object at 0x7f0bfb6a3a00>,)

My launch command and config file are:

accelerate launch --main_process_port $MASTER_PORT --main_process_ip $MASTER_ADDR --config_file fsdp_configs/fsdp_8node_debug_joe.yaml --machine_rank $RANK main/train_sd.py \
    --generator_lr 5e-7 \
    --guidance_lr 5e-7 \
    --train_iters 100000000 \
    --output_path $CHECKPOINT_PATH/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch \
    --batch_size 2 \
    --grid_size 2 \
    --initialie_generator --log_iters 1000 \
    --resolution 1024 \
    --latent_resolution 128 \
    --seed 10 \
    --real_guidance_scale 8 \
    --fake_guidance_scale 1.0 \
    --max_grad_norm 10.0 \
    --model_id "/mnt/nas/gaohl/models/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/462165984030d82259a11f4367a4eed129e94a7b/" \
    --wandb_iters 100 \
    --wandb_entity $WANDB_ENTITY \
    --wandb_project $WANDB_PROJECT \
    --wandb_name "sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch" \
    --log_loss \
    --dfake_gen_update_ratio 5 \
    --fsdp \
    --sdxl \
    --use_fp16 \
    --max_step_percent 0.98 \
    --cls_on_clean_image \
    --gen_cls_loss \
    --gen_cls_loss_weight 5e-3 \
    --guidance_cls_loss_weight 1e-2 \
    --diffusion_gan \
    --diffusion_gan_max_timestep 1000 \
    --denoising \
    --num_denoising_step 4 \
    --denoising_timestep 1000 \
    --backward_simulation \
    --train_prompt_path $CHECKPOINT_PATH/captions_laion_score6.25.pkl \
    --real_image_path $CHECKPOINT_PATH/sdxl_vae_latents_laion_500k_lmdb/

fsdp_8node_debug_joe.yaml:

compute_environment: LOCAL_MACHINE
debug: true
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_min_num_params: 3000
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
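One thing that may be worth checking in a setup like this: the report mentions 2 nodes and 16 GPUs, while the config above lists `num_machines: 1` and `num_processes: 8`. In accelerate configs, `num_processes` is the total across all machines. Nothing in this thread establishes that this is related to the error, so the following is only a minimal sanity-check sketch. The helper names `read_top_level_ints` and `topology_problems` are hypothetical, and the parser only handles plain top-level `key: value` lines, not general YAML.

```python
def read_top_level_ints(yaml_text, keys):
    """Collect integer values for the given top-level keys from simple `key: value` text."""
    found = {}
    for line in yaml_text.splitlines():
        # Skip nested (indented) entries and lines without a key/value separator.
        if line.startswith((" ", "\t")) or ":" not in line:
            continue
        key, _, value = line.partition(":")
        key = key.strip()
        if key in keys:
            found[key] = int(value.strip())
    return found


def topology_problems(config, nodes, gpus_per_node):
    """Return human-readable mismatches between the config and the intended cluster layout."""
    problems = []
    expected_processes = nodes * gpus_per_node
    if config.get("num_machines") != nodes:
        problems.append(
            f"num_machines is {config.get('num_machines')}, expected {nodes}"
        )
    if config.get("num_processes") != expected_processes:
        problems.append(
            f"num_processes is {config.get('num_processes')}, expected {expected_processes}"
        )
    return problems


if __name__ == "__main__":
    # Values copied from the config in this issue; the 2 x 8 layout is from the report.
    text = "num_machines: 1\nnum_processes: 8\n"
    config = read_top_level_ints(text, {"num_machines", "num_processes"})
    for problem in topology_problems(config, nodes=2, gpus_per_node=8):
        print(problem)
```

If both checks pass for your own layout, the config is at least consistent with the number of nodes and GPUs you intend to launch on.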

tianweiy commented 3 months ago

I have no clue how to solve this error.

But the most likely cause of the error is a mismatch between the installed torch and accelerate versions and the ones this repo expects. Could you double-check that your torch and accelerate versions match the ones in the README?

Other versions simply don't work, unfortunately...

Also related to https://github.com/tianweiy/DMD2/issues/25#issuecomment-2222350269
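A quick way to act on the suggestion above is to print the installed versions and compare them with the pins from the README. The sketch below uses only the standard library (`importlib.metadata`). The function names `installed_versions` and `find_mismatches` are made up for illustration, and the expected versions are deliberately left for you to fill in from the README, since they are not quoted in this thread.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: installed version string, or None if the package is missing}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def find_mismatches(installed, expected):
    """Return {package: (installed, expected)} for every package whose version differs."""
    mismatches = {}
    for name, want in expected.items():
        have = installed.get(name)
        # Ignore local build tags such as "+cu118" when comparing.
        have_base = have.split("+")[0] if have else None
        if have_base != want:
            mismatches[name] = (have, want)
    return mismatches


if __name__ == "__main__":
    # Prints whatever is installed in the current environment.
    print(installed_versions(["torch", "accelerate"]))
```

To compare, pass a dict of the README pins as `expected`, for example `find_mismatches(installed_versions(["torch", "accelerate"]), expected)`. An empty result means the two packages match the pins you supplied.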

fire2323 commented 3 months ago

@tianweiy, do you use TorchDynamo with torch FSDP? I found that the error might be related to TorchDynamo compilation.

tianweiy commented 3 months ago

I didn't use TorchDynamo.

BeBuBu commented 3 months ago

Is there any progress on this issue? My accelerate and torch versions are as follows: accelerate 0.25.0, pytorch 2.1.2 (py3.8_cuda11.8_cudnn8.7.0_0).

tianweiy commented 3 months ago

Please just try the versions specified in the README first. Other versions are not tested and likely just don't work.


joelulu commented 3 months ago


Thanks a lot. I solved my problem.

BeBuBu commented 3 months ago

Thank you. Problem solved.

joelulu commented 3 months ago

Does DMD2 have plans to support StableCascade distillation?