tianweiy / DMD2

Resuming from the best checkpoint and training for an additional 2,000 iterations #29

Closed: SYZhang0805 closed this issue 1 month ago

SYZhang0805 commented 1 month ago

When I tried to resume from the best checkpoint and train for an additional 2,000 iterations, I got the following error:

```
Traceback (most recent call last):
  File "main/train_sd.py", line 700, in <module>
    trainer.train()
  File "main/train_sd.py", line 599, in train
    self.train_one_step()
  File "main/train_sd.py", line 355, in train_one_step
    generator_loss_dict, generator_log_dict = self.model(
  File "/home/amax/.conda/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amax/Documents/zsy/DMD2-main/main/sd_unified_model.py", line 305, in forward
    loss_dict, log_dict = self.guidance_model(
  File "/home/amax/.conda/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amax/.conda/envs/dmd2/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/amax/.conda/envs/dmd2/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/amax/.conda/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amax/Documents/zsy/DMD2-main/main/sd_guidance.py", line 437, in forward
    loss_dict, log_dict = self.generator_forward(
  File "/home/amax/Documents/zsy/DMD2-main/main/sd_guidance.py", line 339, in generator_forward
    dm_dict, dm_log_dict = self.compute_distribution_matching_loss(
  File "/home/amax/Documents/zsy/DMD2-main/main/sd_guidance.py", line 192, in compute_distribution_matching_loss
    pred_fake_noise = predict_noise(
  File "/home/amax/Documents/zsy/DMD2-main/main/sd_guidance.py", line 36, in predict_noise
    noise_pred = unet(model_input, timesteps, embeddings, added_cond_kwargs=unet_added_conditions).sample
  File "/home/amax/.conda/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amax/Documents/zsy/DMD2-main/main/sd_unet_forward.py", line 122, in classify_forward
    t_emb = self.get_time_embed(sample=sample, timestep=timestep)
  File "/home/amax/.conda/envs/pathTracing/lib/python3.8/site-packages/diffusers/models/modeling_utils.py", line 221, in __getattr__
    return super().__getattr__(name)
  File "/home/amax/.conda/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'UNet2DConditionModel' object has no attribute 'get_time_embed'
```
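One detail in the traceback is worth checking: `torch` is loaded from `/home/amax/.conda/envs/dmd2/...`, but `diffusers` is loaded from `/home/amax/.conda/envs/pathTracing/...`, so the run may not be using the diffusers install you expect. The following is a minimal sketch that reports where each package would be imported from without actually importing it. It uses only the standard library, and `package_origin` / `report_origins` are hypothetical helper names, not part of DMD2.

```python
import importlib.util
from typing import Dict, Iterable, Optional


def package_origin(name: str) -> Optional[str]:
    """Return the file a top-level package would be imported from, or None if it is not importable."""
    spec = importlib.util.find_spec(name)
    # For a regular package, spec.origin is the path of its __init__.py;
    # it is None for namespace packages and when the package is not found.
    return spec.origin if spec is not None else None


def report_origins(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each package name to the location it would be imported from."""
    return {name: package_origin(name) for name in names}


if __name__ == "__main__":
    # Run this with the same interpreter that torchrun launches.
    for name, origin in report_origins(["torch", "diffusers"]).items():
        print(f"{name:10s} -> {origin}")
```

If the two paths point at different conda environments, you can inspect `sys.path` and the `PYTHONPATH` environment variable to see why the second environment is being picked up.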

Here is my command to run the training:

```shell
CUDA_VISIBLE_DEVICES=0,1 torchrun --nnodes 1 --nproc_per_node=2 --rdzv_id=2345 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_IP main/train_sd.py \
    --generator_lr 5e-7 \
    --guidance_lr 5e-7 \
    --train_iters 41000 \
    --output_path log/output \
    --cache_dir log/cache \
    --log_path log/log \
    --batch_size 2 \
    --grid_size 2 \
    --initialie_generator --log_iters 1000 \
    --resolution 512 \
    --latent_resolution 64 \
    --seed 10 \
    --real_guidance_scale 1.75 \
    --fake_guidance_scale 1.0 \
    --max_grad_norm 10.0 \
    --model_id "runwayml/stable-diffusion-v1-5" \
    --train_prompt_path $CHECKPOINT_PATH/captions_laion_score6.25.pkl \
    --wandb_iters 50 \
    --wandb_entity $WANDB_ENTITY \
    --wandb_project $WANDB_PROJECT \
    --wandb_name "laion6.25_sd_baseline_8node_guidance1.75_lr5e-7_seed10_dfake10_diffusion1000_gan1e-3_resume" \
    --use_fp16 \
    --log_loss \
    --dfake_gen_update_ratio 10 \
    --gradient_checkpointing \
    --cls_on_clean_image \
    --gen_cls_loss \
    --gen_cls_loss_weight 1e-3 \
    --guidance_cls_loss_weight 1e-2 \
    --diffusion_gan \
    --diffusion_gan_max_timestep 1000 \
    --ckpt_only_path $CHECKPOINT_PATH/laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch_fid9.28_checkpoint_model_039000 \
    --real_image_path $CHECKPOINT_PATH/sd_vae_latents_laion_500k_lmdb/
```

tianweiy commented 1 month ago

You might need to downgrade the diffusers version to something like 0.27.2.
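To catch a mismatched diffusers version before a long multi-GPU job starts, a training script could fail fast at startup. The following is a minimal sketch that uses the standard-library `importlib.metadata`, which is available in Python 3.8 and so matches the `python3.8` environments in the traceback. `require_version` is a hypothetical helper, not part of DMD2. Requiring an exact match with 0.27.2 is an assumption based on the suggestion above, not a documented requirement.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable


def require_version(package: str, expected: str,
                    get_version: Callable[[str], str] = version) -> str:
    """Raise RuntimeError unless `package` is installed at exactly `expected`; return the installed version."""
    try:
        installed = get_version(package)
    except PackageNotFoundError:
        raise RuntimeError(
            f"{package} is not installed; try: pip install {package}=={expected}"
        ) from None
    if installed != expected:
        # Exact string match: stricter than a version range, but unambiguous.
        raise RuntimeError(
            f"{package}=={installed} is installed but {expected} is expected; "
            f"try: pip install {package}=={expected}"
        )
    return installed
```

You could call `require_version("diffusers", "0.27.2")` near the top of `main/train_sd.py`, before any model is built. The check runs in the same interpreter that later imports diffusers, so it reports the install actually in use. That matters here because the traceback resolves diffusers from a different conda environment than torch.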