huggingface / diffusers


FP32 training for sd3 controlnet #9560

Open xduzhangjiayu opened 1 day ago

xduzhangjiayu commented 1 day ago

Hi, I have been using examples\controlnet\train_controlnet_sd3.py for ControlNet training for a while, and I have a couple of questions I would like your advice on:

  1. At line 1097, vae.to(accelerator.device, dtype=torch.float32) casts the VAE to fp32. But as far as I know, SD3 currently has no fp32 checkpoints, so does it really work to upcast the fp16 weights to fp32? (See the sketch after this list.)

  2. Before running the training script, accelerate config asks whether to use mixed precision. Since SD3 only has an fp16 checkpoint at present, I don't know whether to choose 'fp16' or 'no' for this option.
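
A minimal sketch of what the cast on line 1097 amounts to (the model ID and loading details here are illustrative assumptions, not copied from the script):

```python
import torch
from diffusers import AutoencoderKL

# Illustrative assumption: load the VAE from the fp16 SD3 checkpoint, then
# upcast it. .to(dtype=torch.float32) simply converts the stored fp16 values
# to fp32 in memory; it does not require an fp32 checkpoint on the Hub.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="vae",
    torch_dtype=torch.float16,  # weights as stored on the Hub
)
vae.to("cuda", dtype=torch.float32)  # same effect as line 1097 of the script
print(next(vae.parameters()).dtype)  # torch.float32
```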

Really appreciate your advice! @sayakpaul @DavyMorgan

DavyMorgan commented 1 day ago

Hi, here we use fp32 to prevent numerical instabilities, and you can check this. Also, in my experiments, I chose 'no' during accelerate config.
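
To see the kind of instability fp32 avoids, here is a toy illustration (an assumption for exposition, not code from the repo) of how small updates get lost in fp16:

```python
import torch

# Repeatedly adding an update smaller than fp16's resolution near 1.0
# (~1e-3) leaves the fp16 value unchanged, while fp32 accumulates it.
update = 1e-4
w16 = torch.tensor(1.0, dtype=torch.float16)
w32 = torch.tensor(1.0, dtype=torch.float32)
for _ in range(1000):
    w16 += update
    w32 += update
print(w16.item())  # 1.0 -- every update rounded away in fp16
print(w32.item())  # ~1.1 -- accumulated correctly in fp32
```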

xduzhangjiayu commented 1 day ago

Thanks for the reply! Was the SD3 base model you used fp32?

xduzhangjiayu commented 1 day ago

As the author of ControlNet mentioned here, training ControlNet with FP16 does not work well: https://github.com/lllyasviel/ControlNet/issues/265

DavyMorgan commented 1 day ago

@sayakpaul If I understand it correctly, we cast the fp16 weights to fp32 to prevent numerical instabilities (SD3 currently has no fp32 checkpoints). @xduzhangjiayu Meanwhile, it seems that training ControlNet with FP16 rather than FP32 does not work well, per https://github.com/lllyasviel/ControlNet/issues/265#issuecomment-1466354258
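
A hedged sketch of what that upcast looks like for the trainable ControlNet (the exact loading flow in train_controlnet_sd3.py may differ; the model ID and dtype handling are assumptions):

```python
import torch
from diffusers import SD3Transformer2DModel, SD3ControlNetModel

# Assumed flow: the frozen transformer is loaded from the fp16 checkpoint,
# the ControlNet is initialized from it, and the trainable ControlNet copy
# is kept in fp32 even though the source weights were stored in fp16.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    torch_dtype=torch.float16,  # weights as stored
)
controlnet = SD3ControlNetModel.from_transformer(transformer)
controlnet.to(dtype=torch.float32)  # train in full precision
print(next(controlnet.parameters()).dtype)  # torch.float32
```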

xduzhangjiayu commented 16 hours ago

In my experiment, the loss began to oscillate after ~10,000 steps without converging, and the resulting images were fuzzy and blocky, though semantically correct. I am not sure whether this is related to the fp16 training.