huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

unsupported operand type(s) for //: 'NoneType' and 'int' when running diffusers/examples/consistency_distillation/train_lcm_distill_sd_wds.py #6518

Closed: dreaming-huang closed this issue 9 months ago

dreaming-huang commented 9 months ago

Describe the bug

Reason: the UNet's .config is a FrozenDict, so teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim (line 925) may not change the attribute teacher_unet.config.time_cond_proj_dim. That attribute can therefore still be None, and w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim) (line 1128) receives embedding_dim=None, which raises the TypeError in the title.
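The error message itself is easy to reproduce. The sketch below is a hypothetical, plain-Python stand-in for the first step of a sinusoidal guidance-scale embedding. It assumes (as in the LCM training scripts) that the embedding starts by halving embedding_dim. It is not the script's actual guidance_scale_embedding function.

```python
import math


def half_embedding_dim(embedding_dim):
    # Hypothetical stand-in: a sinusoidal embedding builds its frequency
    # table from half of the requested embedding width.
    half_dim = embedding_dim // 2
    # Log-spaced frequency step, as in standard sinusoidal embeddings.
    step = math.log(10000.0) / (half_dim - 1)
    return half_dim, step


# A configured width works as expected.
print(half_embedding_dim(256)[0])  # 128

# A width that was never set (None) fails with the same message as the issue title.
try:
    half_embedding_dim(None)
except TypeError as err:
    print(err)  # unsupported operand type(s) for //: 'NoneType' and 'int'
```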

I suggest replacing line 925 with the following code:

    teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
    teacher_unet.config.time_cond_proj_dim = args.unet_time_cond_proj_dim
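Why would the second assignment matter? The report implies that item reads and attribute reads on the config can diverge. The sketch below models that with a hypothetical MirroredConfig class. It copies each key to an instance attribute once, at construction time. This is an assumption made for illustration, not diffusers' FrozenDict itself.

```python
from collections import OrderedDict


class MirroredConfig(OrderedDict):
    """Hypothetical stand-in for a config dict that copies each key
    to an instance attribute once, at construction time."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for key, value in self.items():
            setattr(self, key, value)


config = MirroredConfig(time_cond_proj_dim=None)

# Item assignment updates the dict entry only...
config["time_cond_proj_dim"] = 256
print(config["time_cond_proj_dim"])  # 256
# ...while attribute reads still see the value copied at construction.
print(config.time_cond_proj_dim)  # None

# The workaround from the report: update the attribute as well.
config.time_cond_proj_dim = 256
print(config.time_cond_proj_dim)  # 256
```

The two-line workaround keeps the dict entry and the attribute in sync. It still mutates the teacher's config in place, though.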

Reproduction

    export MODEL_NAME="runwayml/stable-diffusion-v1-5"
    export OUTPUT_DIR="sd1-5"
    CUDA_VISIBLE_DEVICES=0,1 accelerate launch --multi_gpu --num_processes 2 train_lcm_distill_sd_wds.py \
      --pretrained_teacher_model=$MODEL_NAME \
      --output_dir=$OUTPUT_DIR \
      --mixed_precision=fp16 \
      --resolution=512 \
      --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
      --max_train_steps=1000 \
      --max_train_samples=4000000 \
      --dataloader_num_workers=8 \
      --train_shards_path_or_url="pipe:curl -L -s https://hf-mirror.com/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
      --validation_steps=200 \
      --checkpointing_steps=200 --checkpoints_total_limit=10 \
      --train_batch_size=4 \
      --gradient_checkpointing \
      --gradient_accumulation_steps=1 \
      --use_8bit_adam \
      --resume_from_checkpoint=latest \
      --report_to=wandb \
      --seed=453645634

Logs

No response

System Info


Who can help?

No response

dg845 commented 9 months ago

I think this can probably be fixed by calling

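    # Prefer the teacher's own value; fall back to the CLI argument only if it is unset.
    time_cond_proj_dim = (
        teacher_unet.config.time_cond_proj_dim
        if teacher_unet.config.time_cond_proj_dim is not None
        else args.unet_time_cond_proj_dim
    )
    # Build the student from the teacher's config, overriding only time_cond_proj_dim.
    unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)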

instead of

https://github.com/huggingface/diffusers/blob/a551ddf928f02419432263f39dc0109b74f004f7/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L924-L927
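The proposed fix avoids mutating the teacher's config at all. The sketch below is a plain-Python model of that idea. It assumes that keyword arguments passed to from_config override the matching config entries. The helpers resolve_time_cond_proj_dim and config_with_overrides are hypothetical, and 256 is only an illustrative CLI value.

```python
def resolve_time_cond_proj_dim(teacher_value, cli_value):
    # Keep the teacher's own setting when it has one; otherwise fall
    # back to the value passed on the command line.
    return teacher_value if teacher_value is not None else cli_value


def config_with_overrides(base_config, **overrides):
    # Hypothetical helper: build a fresh config instead of mutating the
    # original, letting keyword overrides win over base values.
    merged = dict(base_config)
    merged.update(overrides)
    return merged


teacher_config = {"in_channels": 4, "time_cond_proj_dim": None}
dim = resolve_time_cond_proj_dim(teacher_config["time_cond_proj_dim"], 256)
student_config = config_with_overrides(teacher_config, time_cond_proj_dim=dim)

print(student_config["time_cond_proj_dim"])  # 256
print(teacher_config["time_cond_proj_dim"])  # None (teacher left untouched)
```

Because the student config is built fresh, there is no stale attribute to go out of sync. The teacher keeps its original configuration.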

Let me open a PR to fix this.