TencentQQGYLab / ELLA

ELLA: Equip Diffusion Models with LLM for Enhanced Semantic Alignment
https://ella-diffusion.github.io/
Apache License 2.0

Details about the pooled embedding for SDXL #39

Closed jmliu88 closed 6 months ago

jmliu88 commented 6 months ago

Hi, nice work!

I'm trying to reimplement ELLA on SDXL and am confused about the pooled_prompt_embedding.

The original SDXL: the output of the SDXL text encoder comes from two CLIP text encoders concatenated along the channel dimension, while the pooled_prompt_embedding comes only from the second one. The size of the text embedding is 1x77x2048 and the size of the pooled_embedding is 1x1280.
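For reference, a sketch of how these shapes arise with the encoders that diffusers ships for SDXL (the checkpoint name and subfolder layout are assumptions based on the public SDXL base repo):

```python
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

base = "stabilityai/stable-diffusion-xl-base-1.0"
tok1 = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
tok2 = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer_2")
enc1 = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")                  # hidden size 768
enc2 = CLIPTextModelWithProjection.from_pretrained(base, subfolder="text_encoder_2")  # hidden size 1280

prompt = "a corgi wearing sunglasses"
ids1 = tok1(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids
ids2 = tok2(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt").input_ids

with torch.no_grad():
    out1 = enc1(ids1, output_hidden_states=True)
    out2 = enc2(ids2, output_hidden_states=True)

# Penultimate hidden states of both encoders, concatenated along channels: 768 + 1280 = 2048
prompt_embeds = torch.cat([out1.hidden_states[-2], out2.hidden_states[-2]], dim=-1)
# The pooled embedding comes only from the second (projection) encoder
pooled_prompt_embeds = out2.text_embeds

print(prompt_embeds.shape)         # torch.Size([1, 77, 2048])
print(pooled_prompt_embeds.shape)  # torch.Size([1, 1280])
```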

It seems straightforward to build connectors that match the size of the text embedding. However, the pooled_prompt_embedding seems to need a special operator in the SDXL architecture. It would be very helpful if you could share details about how the pooled_prompt_embedding is produced!

budui commented 6 months ago

We use an AttentionPool to convert the 77x2048 tensor to 1x1280.
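(A minimal sketch of one possible attention-pooling layer, assuming a single learned query attending over the 77 tokens; the exact layer used in ELLA may differ.)

```python
import torch
import torch.nn as nn

class AttentionPool(nn.Module):
    """Pool a (B, 77, 2048) token sequence into a (B, 1280) vector with one learned query."""

    def __init__(self, input_dim: int = 2048, output_dim: int = 1280, num_heads: int = 8):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, input_dim) * 0.02)
        self.attn = nn.MultiheadAttention(input_dim, num_heads, batch_first=True)
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        q = self.query.expand(tokens.shape[0], -1, -1)  # (B, 1, 2048)
        pooled, _ = self.attn(q, tokens, tokens)        # attend over all 77 tokens -> (B, 1, 2048)
        return self.proj(pooled.squeeze(1))             # (B, 1280)

pool = AttentionPool()
print(pool(torch.randn(2, 77, 2048)).shape)  # torch.Size([2, 1280])
```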

jmliu88 commented 6 months ago

Thanks. It helps!

eyalgutflaish commented 6 months ago

Hi, it seems that the PerceiverResampler could do this by setting num_latents=1. Do you agree?

Additionally, I'd guess the pooling operation needs far fewer parameters, since it works on the output, i.e. the encoder_hidden_states, which have already gone through a lot of processing. Do you agree? :)

jmliu88 commented 6 months ago

In my implementation, I pick the last latent's output as the pooled embedding, which I guess is theoretically equivalent to an AttentionPool.

I also index the first 1280 of the 2048 channels to match the dimensions of SDXL. This line generates the pooled embedding: pooled_embedding = embedding[:, -1, :1280]
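(A minimal sketch of the indexing described above, assuming embedding is the (B, num_latents, 2048) output of a PerceiverResampler-style connector; the connector itself is omitted.)

```python
import torch

embedding = torch.randn(2, 64, 2048)        # stand-in for connector(llm_hidden_states)

prompt_embeds = embedding                   # (B, 64, 2048), fed to cross-attention as usual
pooled_embedding = embedding[:, -1, :1280]  # last latent, first 1280 channels -> (B, 1280)
print(pooled_embedding.shape)               # torch.Size([2, 1280])
```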

My model converges with this implementation. However, a quick training run on the SAM1b data from PixArt doesn't yield good prompt-following results. Perhaps re-captioning is needed.

andupotorac commented 3 months ago

@jmliu88 Did you have any luck with ELLA for SDXL?