BatsResearch / csp

Learning to compose soft prompts for compositional zero-shot learning.
BSD 3-Clause "New" or "Revised" License

Why do you use argmax() to find the eos token #8

Closed gulzainali98 closed 2 years ago

gulzainali98 commented 2 years ago

eos_idx = tokenized[idx].argmax() gives me index 3 for my custom words when using context length 16.


# 49407 (the end token) is the largest token id, so argmax gives the EOS position.
eos_idx = int(self.token_ids[0].argmax())
soft_embeddings = self.attr_dropout(self.soft_embeddings)
# The prompt ends with "<attribute> <object> <eos>", so the attribute embedding
# is written at eos_idx - 2 and the object embedding at eos_idx - 1.
token_tensor[:, eos_idx - 2, :] = soft_embeddings[
    attr_idx
].type(self.clip_model.dtype)
token_tensor[:, eos_idx - 1, :] = soft_embeddings[
    obj_idx + self.offset
].type(self.clip_model.dtype)

How can you say that eos_idx - 2 will be the place for the attribute embedding and eos_idx - 1 for the object embedding with context length 8? What if I want another word in between the attribute and object, e.g. 'a photo of car with wet look'? How would I manage the replacement then?

nihalnayak commented 2 years ago

You bring up a good point. We assume the class names will be of the form <attribute> <object> in the compositional zero-shot learning task. But, our framework can be easily adapted to a more generalized setting as well.

Tokenization in CLIP. We use the CLIP tokenizer to build the token ids for our prompt, i.e., the tokenizer converts tokens to ids. For example, if you tokenize the prompt a photo of car with wet look, you get the following output:

In [3]: clip.tokenize("a photo of car with wet look", context_length=10)
Out[3]: tensor([[49406,   320,  1125,   539,  1615,   593,  6682,  1012, 49407,     0]])

where 49406 corresponds to the start token, 49407 corresponds to the end token, and 0 corresponds to the padding. You can play around with the context_length to add more padding. A smaller context_length creates a smaller attention mask in the CLIP architecture, which saves memory during training and testing. But, if the context_length is too small, the tokenizer throws an error saying the input prompt is too long for the context length.
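This is also why the code you quoted uses argmax() to find the end token: 49407 is the largest id in the CLIP vocabulary, so the position of the maximum over the token ids is exactly the EOS position. A minimal sketch, reusing the example above:

import clip

# 49407 (the end token) is the largest id in CLIP's vocabulary, so argmax over
# the token ids returns the position of the end token.
token_ids = clip.tokenize("a photo of car with wet look", context_length=10)
eos_idx = int(token_ids[0].argmax())
print(eos_idx)                        # 8 for the example above
print(token_ids[0][eos_idx].item())   # 49407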

Extension. To extend our framework to your setting, you could map the token id 1615, which corresponds to car, to a custom token or learned representation, and replace it in your token_tensor.
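A rough sketch of one way to do that (this is not the code in this repo; token_id_to_soft_idx and noun_soft_embeddings are hypothetical names, and the ViT-B/32 model is loaded only for illustration):

import clip
import torch

clip_model, _ = clip.load("ViT-B/32", device="cpu")

# Hypothetical map from CLIP token ids to rows of a learned embedding table.
token_id_to_soft_idx = {1615: 0}  # 1615 is the CLIP token id for "car"
noun_soft_embeddings = torch.nn.Parameter(
    torch.zeros(len(token_id_to_soft_idx), clip_model.token_embedding.embedding_dim)
)

token_ids = clip.tokenize("a photo of car with wet look", context_length=10)
token_tensor = clip_model.token_embedding(token_ids).type(clip_model.dtype)

# Overwrite every position whose token id has a learned representation.
for pos, tok in enumerate(token_ids[0].tolist()):
    if tok in token_id_to_soft_idx:
        token_tensor[:, pos, :] = noun_soft_embeddings[
            token_id_to_soft_idx[tok]
        ].type(clip_model.dtype)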

Tip: Keep in mind that words are sometimes split into subwords, so there is no 1:1 mapping between words and token ids. For example, the word compositional corresponds to the token ids 7783, 760, 918. You might need a heuristic to deal with these corner cases.
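One simple heuristic is to tokenize the word on its own, drop the start/end/padding ids, and then search for that contiguous span of subword ids inside the prompt. A rough sketch (find_word_span is a hypothetical helper, not part of this repo):

import clip

def find_word_span(prompt_ids, word):
    # Tokenize the word alone and keep only its subword ids, dropping the
    # start token (49406), end token (49407), and padding (0).
    word_ids = [t for t in clip.tokenize(word)[0].tolist() if t not in (49406, 49407, 0)]
    ids = prompt_ids.tolist()
    # Slide over the prompt ids looking for the contiguous subword span.
    for start in range(len(ids) - len(word_ids) + 1):
        if ids[start:start + len(word_ids)] == word_ids:
            return start, start + len(word_ids)
    return None

prompt_ids = clip.tokenize("a photo of compositional learning")[0]
print(find_word_span(prompt_ids, "compositional"))  # e.g. (4, 7) for the three subword ids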

I hope this helps! Good luck!