rinongal / textual_inversion

MIT License

[Question] Which embedding is actually optimised? #120

Closed: bonlime closed this issue 1 year ago

bonlime commented 1 year ago

Hey! In the paper (page 19) you write:

However, rather than modifying the conditioning code cθ(y) directly, we change the conditioning text y.

This suggests that you're optimising the input embedding of the text encoder (Figure 2 seems to show the same). However, the code seems to optimise the embedding that goes directly into the LDM (i.e., the output of the text encoder). Am I correct?

rinongal commented 1 year ago

No. The paper is correct, and the code follows the paper.

Here is the relevant line in the code where we actually do the replacement: https://github.com/rinongal/textual_inversion/blob/0a950b482d2e8f215122805d4c5901bdb4a6947f/ldm/modules/x_transformer.py#L615

You'll see that this is after the token embedding step, and before the transformer adds the positional encoding and runs all of the text encoder's actual layers.
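The order of operations described above can be sketched with a toy NumPy encoder. Everything here (the sizes, the random weights, the hypothetical `PLACEHOLDER_ID`, and the single mixing step standing in for the transformer layers) is an illustrative assumption, not the repo's `x_transformer` code. It only shows that the learned vector replaces a row of the token-embedding output before positional encodings are added and the encoder layers run.

```python
import numpy as np

# Toy sizes and a hypothetical placeholder token id (assumptions, not the repo's values).
VOCAB, DIM, SEQ = 10, 8, 5
PLACEHOLDER_ID = 9

rng = np.random.default_rng(0)
token_emb = rng.normal(size=(VOCAB, DIM))           # frozen token-embedding table
pos_emb = rng.normal(size=(SEQ, DIM))               # frozen positional embeddings
W_mix = rng.normal(size=(DIM, DIM)) / np.sqrt(DIM)  # stand-in for the transformer layers


def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def embed(token_ids, learned_vec):
    # 1. Token-embedding lookup (fancy indexing returns a copy).
    x = token_emb[token_ids]
    # 2. Swap in the learned vector wherever the placeholder token appears.
    x[token_ids == PLACEHOLDER_ID] = learned_vec
    # 3. Only afterwards add the positional encodings.
    return x + pos_emb[: len(token_ids)]


def encode(token_ids, learned_vec):
    x = embed(token_ids, learned_vec)
    # 4. Toy "layers": every position receives information pooled from all
    #    positions, so the learned vector can influence the whole prompt.
    h = x + np.tanh(x @ W_mix).mean(axis=0, keepdims=True)
    return layer_norm(h)  # final layer norm
```

Because the learned vector enters before the mixing step, changing it also changes the outputs at the other positions in the prompt, and gradients to it must flow back through the encoder.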

bonlime commented 1 year ago

Thanks for the quick response. I got confused by the code in the EmbeddingManager class, but now I see that you're right.

Is there a particular motivation for optimising the input embedding rather than the output? It seems to me that the output should be easier to optimise, since the optimisation path is shorter. A naive approach may not work, but passing the learned embedding through the text encoder's final layer norm, to keep it close to the other output embeddings, could. It would also train faster, since we wouldn't need to backpropagate through text_encoder.
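The output-side alternative proposed above can be sketched with the same kind of toy NumPy encoder. All names, sizes and weights (`frozen_encoder`, `encode_output_side`, `PLACEHOLDER_ID`, the single mixing step) are hypothetical illustrations, not code from this repo or from CLIP. The frozen encoder runs once, and only the placeholder's output row is overwritten with a learned vector passed through a layer norm, so gradients to that vector never pass through the encoder.

```python
import numpy as np

# Toy frozen "text encoder"; sizes, weights and the placeholder id are assumptions.
VOCAB, DIM = 10, 8
PLACEHOLDER_ID = 9

rng = np.random.default_rng(0)
token_emb = rng.normal(size=(VOCAB, DIM))
pos_emb = rng.normal(size=(16, DIM))
W_mix = rng.normal(size=(DIM, DIM)) / np.sqrt(DIM)


def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def frozen_encoder(token_ids):
    x = token_emb[token_ids] + pos_emb[: len(token_ids)]
    h = x + np.tanh(x @ W_mix).mean(axis=0, keepdims=True)  # toy token mixing
    return layer_norm(h)


def encode_output_side(token_ids, learned_out):
    # The placeholder id just uses its ordinary, untrained table entry here.
    out = frozen_encoder(token_ids)
    # Overwrite the placeholder's *output* row; the layer norm keeps its
    # statistics in line with the other output rows.
    out[token_ids == PLACEHOLDER_ID] = layer_norm(learned_out)
    return out
```

In this toy the context rows (every position except the placeholder) do not depend on `learned_out` at all, so the encoder never gets to combine the new concept with the rest of the prompt.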

rinongal commented 1 year ago

The motivation for optimizing the input is that it allows the text encoder to actually operate on these new words. Ideally we'd want the encoder to reason over both the new concepts and its prior knowledge, and it will have a harder time doing that if we optimize the output.

You can also look at the HuggingFace experiments here, which show that optimizing the word embeddings / text encoder has a critical impact on result quality compared with, for example, tuning only the U-Net.