huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Provided pooled_prompt_embeds is overwritten via prompt_embeds[0] #7365

Open cloneofsimo opened 4 months ago

cloneofsimo commented 4 months ago

https://github.com/huggingface/diffusers/blob/25caf24ef90fc44074f4fd3712f6ed5a1db4a5c3/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L386

Simple fix:

pooled_prompt_embeds = prompt_embeds[0] if pooled_prompt_embeds is None else pooled_prompt_embeds
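
In context, the guarded assignment might look roughly like the sketch below (this is not the verbatim diffusers source; the helper name is made up for illustration):

    def encode_with_guard(text_encoder, text_input_ids, pooled_prompt_embeds=None):
        # Hypothetical helper sketching the fix, not an actual diffusers function.
        out = text_encoder(text_input_ids, output_hidden_states=True)
        # Only derive the pooled embedding from the encoder output when the caller
        # did not pass one, so a provided pooled_prompt_embeds is never overwritten.
        # out[0] mirrors the pipeline's prompt_embeds[0] access (see the discussion
        # below about what that index actually holds).
        if pooled_prompt_embeds is None:
            pooled_prompt_embeds = out[0]
        # Per-token embeddings still come from the penultimate hidden layer,
        # as in the SDXL pipeline.
        prompt_embeds = out.hidden_states[-2]
        return prompt_embeds, pooled_prompt_embeds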

Sorry this isn't a pr :P

sayakpaul commented 4 months ago

Golden catch. Please PR it.

bghira commented 4 months ago

the SDXL training scripts are broken in this way as well, @patrickvonplaten @sayakpaul, which means the models trained with diffusers are likely all broken.

bghira commented 4 months ago

the prompt embed code also pools incorrectly. as tested by @AmericanPresidentJimmyCarter it looks like the correct method would be to pool the final layer while sampling hidden states from the penultimate layer.

currently, it just retrieves the first element and treats that as if it were the pooled embed. the same thing happens for the negative prompt.

it should run mean(dim=1) over the final layer instead.
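
To make the two pooling strategies concrete, here is a small sketch of the difference; the checkpoint name and variable names are illustrative, not taken from the pipeline:

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    ids = tokenizer("a photo of a cat", padding="max_length", return_tensors="pt").input_ids
    with torch.no_grad():
        out = text_encoder(ids, output_hidden_states=True)

    # Penultimate layer: per-token states used for cross-attention, shape (1, 77, 768).
    per_token_embeds = out.hidden_states[-2]
    # Proposed pooled embed: mean over the token axis of the final layer, shape (1, 768).
    mean_pooled = out.last_hidden_state.mean(dim=1)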

AmericanPresidentJimmyCarter commented 4 months ago

So I was unsure how it was pooled before, and we were going through the code trying to figure it out. It seems like

pooled_prompt_embeds = prompt_embeds[0]

Should really be

    pooled_prompt_embeds = list(filter(lambda x: x is not None, [
        getattr(prompt_embeds, 'pooler_output', None),
        getattr(prompt_embeds, 'text_embeds', None),
    ]))[0]

For clarity, since the two CLIP models return completely different output classes and keep their pooled outputs in different attributes. My concern was originally that, for one of the CLIP models, instead of using the pooled output we were actually selecting the first token of the last hidden layer with [0]. Simply referencing it as [0] is extremely unclear given the nature of the output from the two CLIP models.
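
For reference, a quick sketch of how the two output classes differ (the checkpoint here is a stand-in, not the SDXL pair):

    from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    ids = tokenizer("a photo of a cat", return_tensors="pt").input_ids

    # CLIPTextModel -> BaseModelOutputWithPooling:
    #   out_1[0] is last_hidden_state; the pooled vector lives in out_1.pooler_output
    out_1 = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")(ids)

    # CLIPTextModelWithProjection -> CLIPTextModelOutput:
    #   out_2[0] is text_embeds, i.e. the already pooled and projected vector
    out_2 = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")(ids)

    print(out_1[0].shape, out_1.pooler_output.shape)  # (1, seq, 768) vs (1, 768)
    print(out_2[0].shape)                             # (1, 768) -- already pooled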

sayakpaul commented 4 months ago

The order in which the tokenizers and text encoders are passed matters, so I think the implementation is correct. If a clarifying comment would help, please file a PR; more than happy to prioritize that.

> Simply referencing it as [0] is extremely unclear given the nature of the output from the two CLIP models.

The reason it's there is that it helps with torch.compile(); otherwise there's a TensorVariableAccess problem.
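
To make the ordering point concrete, the loop effectively behaves like the sketch below; this is an illustrative reimplementation, not the verbatim pipeline code, and tokenizers/text_encoders are assumed to be SDXL's (tokenizer, tokenizer_2) and (text_encoder, text_encoder_2) pairs:

    import torch

    def encode_sdxl_prompt(prompt, tokenizers, text_encoders):
        # Illustrative sketch of encode_prompt()'s loop; not the verbatim diffusers code.
        prompt_embeds_list = []
        pooled_prompt_embeds = None
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            ids = tokenizer(prompt, padding="max_length", truncation=True,
                            return_tensors="pt").input_ids
            out = text_encoder(ids, output_hidden_states=True)
            # Overwritten on every iteration: after the loop this holds the *second*
            # encoder's out[0], i.e. its projected pooled text_embeds -- which is
            # why the order of the encoders matters.
            pooled_prompt_embeds = out[0]
            prompt_embeds_list.append(out.hidden_states[-2])
        prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
        return prompt_embeds, pooled_prompt_embeds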

github-actions[bot] commented 3 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

bghira commented 3 months ago

not stale

github-actions[bot] commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

bghira commented 2 months ago

still not stale :D

bghira commented 2 months ago

@sayakpaul I opened the pull request for this, but the code in question only runs when prompt embeds are None.

do we want to mix and match provided pooled embeds with generated prompt embeds?