mindspore-lab / mindone

one for all, Optimal generator with No Exception
Apache License 2.0

fix(diffusers/pipelines): fix vae encoding in pipelines #531

Closed townwish4git closed 2 weeks ago

townwish4git commented 3 weeks ago

What does this PR do?

Fix the incorrect call to vae.diag_gauss_dist

In pipelines for img2img, inpainting, and other tasks that take images as input (such as StableDiffusionImg2ImgPipeline), the input image is first encoded into a Gaussian distribution by pipeline.vae.encode(). The pipeline then calls retrieve_latents() to obtain a concrete latent tensor from that distribution. The sample_mode parameter of retrieve_latents() selects how that tensor is obtained.

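With sample_mode="argmax", the latents are supposed to be $$latents = \mathop{argmax}\limits_{x}\ P_z(x),$$

where $P_z$ is the probability density of the encoded Gaussian distribution. A Gaussian density peaks at its mean, so in this case latents is the mean of the distribution. The original code instead drew a random sample and called argmax() on it, which returns the index of the sample's largest element rather than a latent tensor. The original code should therefore be corrected to: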

def retrieve_latents(vae, encoder_output: ms.Tensor, sample_mode: str = "sample"):
    ...
    elif sample_mode == "argmax":
-        return vae.diag_gauss_dist.sample(encoder_output).argmax()
+        return vae.diag_gauss_dist.mode(encoder_output)
    ...
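
As a brief check of why the mode equals the mean, assume the encoder yields a diagonal Gaussian with per-element mean $\mu_i$ and variance $\sigma_i^2$ (these symbols are introduced here for illustration and do not appear in the PR). Its log-density is

$$\log P_z(x) = -\frac{1}{2}\sum_i \left[\frac{(x_i-\mu_i)^2}{\sigma_i^2} + \log(2\pi\sigma_i^2)\right]$$

Each squared term is non-negative and vanishes only at $x_i = \mu_i$, so

$$\mathop{argmax}\limits_{x}\ P_z(x) = \mu$$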

We applied this fix to all existing pipelines that use this function.
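
The difference between the old and new behaviour can be illustrated with a small standalone sketch. It does not use MindSpore or the real vae.diag_gauss_dist. The helper names (split_moments, sample, mode, buggy_argmax, retrieve_latents_sketch) are hypothetical, and the layout of the encoder output as a flat [mean..., logvar...] vector is an assumption made only for this illustration.

```python
import math
import random
from typing import List, Optional, Tuple


def split_moments(moments: List[float]) -> Tuple[List[float], List[float]]:
    """Split a flat [mean..., logvar...] vector into two halves (assumed layout)."""
    half = len(moments) // 2
    return moments[:half], moments[half:]


def sample(moments: List[float], rng: random.Random) -> List[float]:
    """Draw one random latent: mean + std * standard normal noise."""
    mean, logvar = split_moments(moments)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mean, logvar)]


def mode(moments: List[float]) -> List[float]:
    """Return the density's maximiser, which for a Gaussian is its mean."""
    mean, _ = split_moments(moments)
    return list(mean)


def buggy_argmax(moments: List[float], rng: random.Random) -> int:
    """Mimic the old code path: argmax over the elements of one random draw.

    The result is a position inside the sample, not a latent vector.
    """
    draw = sample(moments, rng)
    return max(range(len(draw)), key=draw.__getitem__)


def retrieve_latents_sketch(moments: List[float],
                            sample_mode: str = "sample",
                            rng: Optional[random.Random] = None) -> List[float]:
    """Hypothetical stand-in for retrieve_latents() with the corrected branch."""
    rng = rng or random.Random()
    if sample_mode == "sample":
        return sample(moments, rng)
    if sample_mode == "argmax":
        return mode(moments)
    raise ValueError(f"unknown sample_mode: {sample_mode!r}")
```

For example, with moments = [0.5, -1.0, 2.0, 0.0, 0.0, 0.0], retrieve_latents_sketch(moments, "argmax") returns the mean [0.5, -1.0, 2.0] on every call, whereas buggy_argmax(moments, random.Random(0)) returns a single integer index that depends on the random draw.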

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@xxx