luosiallen / latent-consistency-model

Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference
MIT License

Training SD2.1 with v_prediction does not work? #65

Open dcfucheng opened 9 months ago

dcfucheng commented 9 months ago

Hi~ I tried the training script with SD v1.5 (prediction_type: epsilon): https://github.com/luosiallen/latent-consistency-model/blob/main/LCM_Training_Script/consistency_distillation/train_lcm_distill_sd_wds.py

It works and generates normal images, like this:

[image: normal generated sample (SD v1.5)]

When I modified the training script for SD v2.1 (stabilityai/stable-diffusion-2-1, prediction_type: v_prediction), the loss converged normally, but the model cannot generate correct images. It generates noise like this:

[image: noise-like output (SD v2.1, v_prediction)]

I also checked SD v2.0-base (stabilityai/stable-diffusion-2-base, prediction_type: epsilon); it works and generates normal images.

Is LCMScheduler wrong for prediction_type: v_prediction, or does something need to be modified in train_lcm_distill_sd_wds.py for pipelines that use v_prediction?
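For reference, here is a minimal sketch of why the prediction type matters. It assumes the variance-preserving convention used by diffusers schedulers: x_t = alpha_t * x0 + sigma_t * eps, with alpha_t = sqrt(alphas_cumprod[t]), sigma_t = sqrt(1 - alphas_cumprod[t]), and a v_prediction target v = alpha_t * eps - sigma_t * x0. The scalar values are made up, and `ddim_step` is a hypothetical scalar stand-in for the script's `solver.ddim_step`, whose noise argument is an epsilon prediction. Passing a v output where epsilon is expected yields a different x_prev from the correct one.

```python
import math

# Hypothetical toy values for timestep t and the previous DDIM timestep.
alpha_cumprod_t = 0.5
alpha_cumprod_prev = 0.8

# VP convention (assumed): alpha_t**2 + sigma_t**2 == 1.
alpha_t = math.sqrt(alpha_cumprod_t)
sigma_t = math.sqrt(1.0 - alpha_cumprod_t)

x0, eps = 1.0, -0.3  # clean sample and the noise added to it
v = alpha_t * eps - sigma_t * x0  # target of a v_prediction model


def ddim_step(pred_x0, pred_noise, alpha_prev):
    """Deterministic DDIM update (eta = 0); hypothetical stand-in for solver.ddim_step."""
    return math.sqrt(alpha_prev) * pred_x0 + math.sqrt(1.0 - alpha_prev) * pred_noise


# Correct: the noise argument is an epsilon prediction.
x_prev_correct = ddim_step(x0, eps, alpha_cumprod_prev)
# Bug pattern: a v output is passed where epsilon is expected.
x_prev_wrong = ddim_step(x0, v, alpha_cumprod_prev)

print(f"x_prev with epsilon: {x_prev_correct:.4f}")
print(f"x_prev with v passed as epsilon: {x_prev_wrong:.4f}")
```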

Thanks~

smilekitty7 commented 8 months ago

Hi, I encountered the same problem. Have you solved it?

dcfucheng commented 8 months ago


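Maybe it works like this for v_prediction. You need to correct how 'x_prev' is computed: the teacher outputs a v prediction, so convert the guided output to epsilon before passing it to solver.ddim_step.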

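# Replacement for the x_prev computation in train_lcm_distill_sd_wds.py (v_prediction teacher).
# extract_into_tensor, solver, alpha_schedule, sigma_schedule, w, index, start_timesteps,
# noisy_model_input, cond/uncond_pred_x0 and cond/uncond_teacher_output are assumed
# to be defined earlier in that script.

def predicted_epsilon(model_output, timesteps, sample, alphas, sigmas):
    # Convert a v prediction into an epsilon (noise) prediction.
    # With x_t = alpha_t * x0 + sigma_t * eps, v = alpha_t * eps - sigma_t * x0
    # and alpha_t**2 + sigma_t**2 = 1: eps = alpha_t * v + sigma_t * x_t.
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
    pred_epsilon = alphas * model_output + sigmas * sample
    return pred_epsilon

# Classifier-free guidance on the x0 estimate and on the raw v output.
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_v = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)

# ddim_step expects an epsilon prediction, so convert the guided v first.
pred_noise2 = predicted_epsilon(
    pred_v,
    start_timesteps,
    noisy_model_input,
    alpha_schedule,
    sigma_schedule,
)
x_prev = solver.ddim_step(pred_x0, pred_noise2, index)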
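A quick numerical check of the conversion in predicted_epsilon, as a plain-Python sketch under the same variance-preserving assumption (scalar toy values; `to_epsilon`, `to_x0` and `guide` are hypothetical helper names, not functions from the training script). It checks that alpha_t * v + sigma_t * x_t recovers epsilon, that alpha_t * x_t - sigma_t * v recovers x0, and that guiding v first and then converting matches converting the conditional and unconditional outputs first and then guiding.

```python
import math


def to_epsilon(v, x_t, alpha_t, sigma_t):
    # eps = alpha_t * v + sigma_t * x_t  (requires alpha_t**2 + sigma_t**2 == 1)
    return alpha_t * v + sigma_t * x_t


def to_x0(v, x_t, alpha_t, sigma_t):
    # x0 = alpha_t * x_t - sigma_t * v
    return alpha_t * x_t - sigma_t * v


def guide(cond, uncond, w):
    # Same guidance form as the snippet: cond + w * (cond - uncond)
    return cond + w * (cond - uncond)


alpha_cumprod_t = 0.3  # hypothetical toy value
alpha_t, sigma_t = math.sqrt(alpha_cumprod_t), math.sqrt(1.0 - alpha_cumprod_t)

# Round trip: build x_t and v from a known (x0, eps), then recover both.
x0, eps = 0.7, 1.2
x_t = alpha_t * x0 + sigma_t * eps
v = alpha_t * eps - sigma_t * x0
assert math.isclose(to_epsilon(v, x_t, alpha_t, sigma_t), eps)
assert math.isclose(to_x0(v, x_t, alpha_t, sigma_t), x0)

# Guidance before vs. after conversion (toy teacher outputs).
v_cond, v_uncond, w = 0.4, -0.1, 7.5
guided_then_converted = to_epsilon(guide(v_cond, v_uncond, w), x_t, alpha_t, sigma_t)
converted_then_guided = guide(
    to_epsilon(v_cond, x_t, alpha_t, sigma_t),
    to_epsilon(v_uncond, x_t, alpha_t, sigma_t),
    w,
)
assert math.isclose(guided_then_converted, converted_then_guided)
print("all conversion checks passed")
```

The last check holds because the conversion is affine in v and the guidance weights (1 + w) and -w sum to 1, so it should not matter whether guidance is applied to the raw v outputs, as in the snippet above, or to the converted epsilons.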