yuanzhi-zhu / prolific_dreamer2d

Unofficial implementation of 2D ProlificDreamer

Strange artifacts after setting rgb_as_latents to false #11

Closed: daveredrum closed this issue 1 year ago

daveredrum commented 1 year ago

Hi @yuanzhi-zhu,

thanks a lot for this awesome implementation of VSD!

I noticed that the training objective (i.e., the “particle”) is initialized as Gaussian white noise in latent space (see: https://github.com/yuanzhi-zhu/prolific_dreamer2d/blob/main/prolific_dreamer2d.py#L274), and the final image is decoded by the VAE of the diffusion model. The comments suggest that a pure Gaussian in image space results in weird artifacts, and I observed similar behavior after setting rgb_as_latents to false (see the attached figure).

[attached image: final_image_a_photograph_of_an_astronaut_riding_a_horse]

I'm wondering what could be the reason for this? Is this some trick from the original paper?
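
For reference, this is roughly the setup I mean, as a minimal sketch that assumes the diffusers AutoencoderKL API rather than the repo's exact code:

import torch
from diffusers import AutoencoderKL

device, dtype = 'cuda', torch.float32
vae = AutoencoderKL.from_pretrained(
    'stabilityai/stable-diffusion-2-1-base', subfolder='vae', torch_dtype=dtype
).to(device)

# the particle lives in the 4-channel latent space; 512 / 8 = 64 spatial resolution
particles = torch.randn((1, 4, 64, 64), device=device, dtype=dtype)

# after optimization, the latents are decoded back to an RGB image in [-1, 1]
with torch.no_grad():
    image = vae.decode(particles / vae.config.scaling_factor).sample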

yuanzhi-zhu commented 1 year ago

Hi @daveredrum , That's an excellent question.

Since the UNet is trained in the latent space of the VAE, the initial distribution is Gaussian in the latent space, too. That is to say, for regular text-to-image sampling we should make sure the initial input is Gaussian in the latent space. Here are the results of text-to-image sampling (set --generation_mode 't2i') with a Gaussian input in latent space and in RGB space (for the second image I replaced https://github.com/yuanzhi-zhu/prolific_dreamer2d/blob/main/prolific_dreamer2d.py#L274 with the following):

# sample Gaussian noise directly in RGB space instead of the latent space
particles = torch.randn((args.batch_size, 3, args.height, args.width)).to(device, dtype=dtype)
# resize to the VAE input resolution and encode; the encoded latents are no longer Gaussian
rgb_BCHW_512 = F.interpolate(particles, (512, 512), mode="bilinear", align_corners=False)
particles = vae.config.scaling_factor * vae.encode(rgb_BCHW_512).latent_dist.sample().detach().clone()

[progressive sampling results: a_photograph_of_an_astronaut_riding_a_horse_prgressive (latent-space Gaussian) and a_photograph_of_an_astronaut_riding_a_horse_prgressive (RGB-space Gaussian)]

We've verified that for a non-Gaussian input, the diffusion model leads us nowhere!!

However, SDS and VSD are optimization-based methods rather than sampling methods, which means the optimization will land at some reasonable output regardless of the input:

[progressive optimization results: a_photograph_of_an_astronaut_riding_a_horse_prgressive and a_photograph_of_an_astronaut_riding_a_horse_prgressive]

Initialization does not matter, which means we have to blame something else!!

To the best of my knowledge, the artifact is caused by the VAE because the loss has to backpropagate through it when rgb_as_latents is set to false.
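
To illustrate, here is a minimal sketch (not the repo's exact code; it assumes the diffusers AutoencoderKL API and uses a random stand-in for the VSD gradient on the latents):

import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL

device, dtype = 'cuda', torch.float32
vae = AutoencoderKL.from_pretrained(
    'stabilityai/stable-diffusion-2-1-base', subfolder='vae', torch_dtype=dtype
).to(device)

# with rgb_as_latents=False the particle is an RGB image that requires gradients
rgb_particle = torch.randn((1, 3, 512, 512), device=device, dtype=dtype, requires_grad=True)

# every optimization step encodes the image into latents; the graph is kept (no detach here)
rgb_512 = F.interpolate(rgb_particle, (512, 512), mode='bilinear', align_corners=False)
latents = vae.config.scaling_factor * vae.encode(rgb_512).latent_dist.sample()

# the VSD/SDS gradient is defined on the latents (random stand-in below),
# so the chain rule pulls it back through vae.encode to the RGB particle
grad_latents = torch.randn_like(latents)
latents.backward(gradient=grad_latents)
print(rgb_particle.grad.shape)  # torch.Size([1, 3, 512, 512])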

Hopefully, this issue can be addressed by the IF model, which does not have a VAE.

daveredrum commented 1 year ago

Thanks a lot for the clarification! It seems that operating in the VAE's latent space is the best way to mitigate these artifacts. Closing the issue now ^^

yuanzhi-zhu commented 1 year ago

Hi @daveredrum ,

I just found that you can simply run more optimization steps to get better results (though still worse than setting rgb_as_latents to True). Here is the result of VSD with 2000 steps:

[attached image: a_photograph_of_an_astronaut_riding_a_horse_image_step2016_t363]

This problem is also discussed at https://github.com/ashawkey/stable-dreamfusion/issues/96

Also, in the ThreeStudio implementation they set rgb_as_latents to False by default.

daveredrum commented 1 year ago

I tried to reproduce the image you shared, but the results I got still seem off...

Here are images I optimized from two runs:

[attached image: final_image_a_photograph_of_an_astronaut_riding_a_horse (run 1)]

[attached image: final_image_a_photograph_of_an_astronaut_riding_a_horse (run 2)]

Below is the script I used to generate those images:

# Image generation with prolific_dreamer2d
### VSD
python prolific_dreamer2d.py \
        --num_steps 2000 --log_steps 50 \
        --seed 1024 --lr 0.03 --phi_lr 0.0001 --use_t_phi true \
        --model_path 'stabilityai/stable-diffusion-2-1-base' \
        --loss_weight_type '1m_alphas_cumprod' --t_schedule 't_stages2' \
        --generation_mode 'vsd' \
        --phi_model 'lora' --lora_scale 1. --lora_vprediction false \
        --prompt "a photograph of an astronaut riding a horse" \
        --height 512 --width 512 --batch_size 1 --guidance_scale 7.5 \
        --log_progress true --save_x0 true --save_phi_model true \
        --rgb_as_latents false

Did you also modify other parameters, e.g. the learning rate?

yuanzhi-zhu commented 1 year ago

This might be caused by the t_schedule you chose.

Nevertheless, the outcomes with rgb_as_latents set to False are notably inferior to those with it set to True.