Closed. dreaming-huang closed this issue 9 months ago.
I think this can probably be fixed by calling
time_cond_proj_dim = (
    teacher_unet.config.time_cond_proj_dim
    if teacher_unet.config.time_cond_proj_dim is not None
    else args.unet_time_cond_proj_dim
)
unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)
instead of
Let me open a PR to fix this.
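A minimal sketch of the idea behind this fix, in plain Python so it runs without diffusers: resolve the dimension once into a local variable, then pass it to the constructor as an override instead of writing into the teacher's config. MappingProxyType stands in for the read-only config; the helpers from_config and resolve_time_cond_proj_dim and the fallback value 256 are hypothetical illustrations, not the diffusers API.

```python
from types import MappingProxyType


def resolve_time_cond_proj_dim(teacher_config, fallback_dim):
    # Prefer the teacher's own value; use the fallback only when it is None.
    teacher_dim = teacher_config.get("time_cond_proj_dim")
    return teacher_dim if teacher_dim is not None else fallback_dim


def from_config(config, **overrides):
    # Hypothetical stand-in for a from_config(config, **kwargs) call:
    # copy the config, apply the overrides, return a new read-only mapping.
    # The input config is never written to.
    merged = dict(config)
    merged.update(overrides)
    return MappingProxyType(merged)


# Illustrative values: a teacher config with no time-conditioning projection,
# and a hypothetical fallback of 256 standing in for args.unet_time_cond_proj_dim.
teacher_config = MappingProxyType({"sample_size": 64, "time_cond_proj_dim": None})
time_cond_proj_dim = resolve_time_cond_proj_dim(teacher_config, fallback_dim=256)
student_config = from_config(teacher_config, time_cond_proj_dim=time_cond_proj_dim)

print(time_cond_proj_dim)                    # 256
print(student_config["time_cond_proj_dim"])  # 256
print(teacher_config["time_cond_proj_dim"])  # None
```

Because the same time_cond_proj_dim variable feeds the student config, a later guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim) call would see the resolved value rather than a stale config attribute.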
Describe the bug
Reason: the .config of a UNet is a FrozenDict, so teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim (line 925) updates the dictionary entry but may not change the attribute teacher_unet.config.time_cond_proj_dim. As a result, w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim) (line 1128) receives the wrong input.
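To see why a stale value matters, here is a simplified, hypothetical pure-Python stand-in for guidance_scale_embedding: it embeds one scalar guidance scale w with the usual sinusoidal scheme and returns embedding_dim features. It is not the script's tensor implementation, and it assumes the stale attribute still holds None, as it would for a teacher such as Stable Diffusion v1.5 whose config does not set time_cond_proj_dim.

```python
import math


def guidance_scale_embedding(w, embedding_dim=512):
    # Hypothetical scalar version: scale w, then build half_dim sine
    # and half_dim cosine features at geometrically spaced frequencies.
    w = w * 1000.0
    half_dim = embedding_dim // 2  # raises TypeError if embedding_dim is None
    freq_step = math.log(10000.0) / (half_dim - 1)
    freqs = [math.exp(-freq_step * i) for i in range(half_dim)]
    emb = [math.sin(w * f) for f in freqs] + [math.cos(w * f) for f in freqs]
    if embedding_dim % 2 == 1:
        emb.append(0.0)  # pad odd sizes with one zero so the length matches
    return emb


print(len(guidance_scale_embedding(8.0, embedding_dim=256)))  # 256

try:
    guidance_scale_embedding(8.0, embedding_dim=None)
except TypeError as exc:
    print("stale None dimension:", exc)
```

With None the integer division fails immediately; with a wrong integer the embedding width would not match what the student UNet's time-conditioning projection expects.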
Suggested fix: replace line 925 with the following code, which sets both the dictionary entry and the attribute:
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
teacher_unet.config.time_cond_proj_dim = args.unet_time_cond_proj_dim
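The entry/attribute mismatch and the two-line fix can be reproduced without diffusers. AttrSnapshotDict below is a hypothetical, simplified stand-in, not diffusers' FrozenDict: it assumes, as the report describes, that the config copies each key into an attribute at construction time and that later item assignment succeeds without refreshing that attribute.

```python
from collections import OrderedDict


class AttrSnapshotDict(OrderedDict):
    # Hypothetical stand-in: every key is copied into an instance attribute
    # once, when the object is built. Later item writes only touch the entry.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for key, value in self.items():
            setattr(self, key, value)


config = AttrSnapshotDict(sample_size=64, time_cond_proj_dim=None)

# Item assignment alone (what line 925 does): the attribute stays stale.
config["time_cond_proj_dim"] = 256
print(config["time_cond_proj_dim"], config.time_cond_proj_dim)  # 256 None

# Suggested fix: also assign the attribute, so both views agree.
config.time_cond_proj_dim = 256
print(config["time_cond_proj_dim"], config.time_cond_proj_dim)  # 256 256
```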
Reproduction
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="sd1-5"
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --multi_gpu --num_processes 2 train_lcm_distill_sd_wds.py \
    --pretrained_teacher_model=$MODEL_NAME \
    --output_dir=$OUTPUT_DIR \
    --mixed_precision=fp16 \
    --resolution=512 \
    --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
    --max_train_steps=1000 \
    --max_train_samples=4000000 \
    --dataloader_num_workers=8 \
    --train_shards_path_or_url="pipe:curl -L -s https://hf-mirror.com/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
    --validation_steps=200 \
    --checkpointing_steps=200 --checkpoints_total_limit=10 \
    --train_batch_size=4 \
    --gradient_checkpointing \
    --gradient_accumulation_steps=1 \
    --use_8bit_adam \
    --resume_from_checkpoint=latest \
    --report_to=wandb \
    --seed=453645634
Logs
No response
System Info
diffusers version: 0.26.0.dev0
Who can help?
No response