BatsResearch / csp

Learning to compose soft prompts for compositional zero-shot learning.
BSD 3-Clause "New" or "Revised" License

Soft embeddings for attributes and objects #20

Closed ans92 closed 8 months ago

ans92 commented 8 months ago

Hi, thank you for the great work. I want to ask about the csp_init function in csp.py, where you assign the mean of orig_token_embedding to soft_embedding. There you take argmax(). In most cases the argmax() is 2, so rep[1:eos_idx, :] has shape [1, 768]. Does this [1, 768] represent the whole composition pair, i.e. (attribute, object)? Some rep[1:eos_idx, :] slices are [2, 768] and one is [3, 768]. Do these represent the composition pair embeddings? If yes, why do most have shape [1, 768] and only a few have shape [2, 768]?

My second question: from your optimizer, it appears that you only optimize and back-propagate the loss into these soft embeddings, and nothing in CLIP itself is updated or trained. Am I right? Can you please guide me on both of these questions?

nihalnayak commented 8 months ago

Thanks for your questions!

To answer your first question, we take the argmax() because we want to average the token embeddings of each concept (attribute or object) that spans more than one token into a single token embedding. This lets us easily replace them in csp.py.
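To make that concrete, here is a minimal sketch of the initialization step, assuming the OpenAI `clip` package and a ViT-L/14 backbone (whose 768-dim token embeddings match the shapes you mention). The variable names mirror this discussion but are illustrative, not the exact code in csp.py:

```python
import clip
import torch

# Illustrative sketch, not the exact csp.py code: build one soft embedding
# per concept by averaging its CLIP token embeddings.
clip_model, _ = clip.load("ViT-L/14", device="cpu")  # token embedding dim = 768

concepts = ["red", "fire truck"]     # attributes/objects; some tokenize to >1 token
tokens = clip.tokenize(concepts)     # [num_concepts, 77]: SOT, tokens, EOT, padding

with torch.no_grad():
    orig_token_embedding = clip_model.token_embedding(tokens)  # [num_concepts, 77, 768]

soft_embedding = torch.zeros(len(concepts), orig_token_embedding.shape[-1])
for i, rep in enumerate(orig_token_embedding):
    # CLIP's EOT token has the largest id in the vocabulary, so argmax over
    # the token ids gives its position in the sequence.
    eos_idx = int(tokens[i].argmax())
    # Average everything between SOT (index 0) and EOT: a one-token concept
    # yields a [1, 768] slice, a two-token concept a [2, 768] slice, etc.
    soft_embedding[i] = rep[1:eos_idx, :].mean(dim=0)
```

The point is that every concept ends up as exactly one 768-dim vector regardless of how many tokens it spans, which is why most slices are [1, 768] (single-token words) and only a few are [2, 768] or [3, 768].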

You are right about the second question: we only optimize the soft embeddings and nothing else. Our paper introduces a new soft prompting method where we fine-tune only these embeddings.
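Concretely, the setup looks something like the following sketch, continuing from the one above (the optimizer choice and learning rate here are illustrative assumptions, not the paper's exact hyperparameters):

```python
import clip
import torch

# Continuing the sketch above: freeze all of CLIP and make only the soft
# embeddings trainable. Optimizer and learning rate are illustrative.
clip_model, _ = clip.load("ViT-L/14", device="cpu")
for param in clip_model.parameters():
    param.requires_grad = False

# In practice this would be the averaged embeddings from the init sketch.
soft_embedding = torch.nn.Parameter(torch.zeros(2, 768))
optimizer = torch.optim.Adam([soft_embedding], lr=5e-4)

# During training, the loss back-propagates only into soft_embedding:
# optimizer.zero_grad(); loss.backward(); optimizer.step()
```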

Hope this helps!