xmed-lab / CLIPN

ICCV 2023: CLIPN for Zero-Shot OOD Detection: Teaching CLIP to Say No
MIT License

Question concerning OOD detection #7

Open romain-martin opened 9 months ago

romain-martin commented 9 months ago

First of all, thank you for your work. The method is promising and your article is very interesting, so I tried to use it in two ways:

I'm using the .pt weights you kindly provided, and I tried to implement the ATD and CTW methods. However, the results were really bad, which leads me to think I missed something. In my first use case the prompt was only "A photo of a person with a {}" ("A photo of a person without a {}") with "hat", "cap", "helmet" as the class names. Using ATD everything is considered OOD; using CTW almost everything is considered ID. I have some questions regarding your paper:

- Do you have a reference or a paper explaining where Eq. 4 comes from? Regarding the CTW method, Eq. 4 should be over 0.5 for the classification to be OOD.
- Where does Eq. 8 come from?
- For Eq. 6, to compute p_ij, this is a kind of softmax, right? Just with an added temperature parameter? In that case, wouldn't the ATD method be unusable when you only have one class and just want to discard false positives, since p_ij is then always equal to 1?

The first thing that came to my mind was to find the index of the maximum value in logits and check logits[index] > logits_no[index] to decide whether it is ID or OOD, but I suppose that is mathematically incorrect since you didn't mention it in your paper, and the test I ran also led to bad results.
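For reference, here is how I currently understand Eq. 6 and Eq. 8 (the notation is mine: $f_i$ is the image feature, $g_j$ / $g_j^{no}$ the "yes"/"no" text features for class $j$), please correct me if I got them wrong:

$$p_{ij} = \frac{e^{\langle f_i, g_j\rangle/\tau}}{\sum_{k=1}^{C} e^{\langle f_i, g_k\rangle/\tau}}, \qquad p^{no}_{ij} = \frac{e^{\langle f_i, g^{no}_j\rangle/\tau}}{e^{\langle f_i, g_j\rangle/\tau} + e^{\langle f_i, g^{no}_j\rangle/\tau}}, \qquad p_{i,\mathrm{ood}} = 1 - \sum_{j=1}^{C}\bigl(1 - p^{no}_{ij}\bigr)\,p_{ij}$$

with CTW flagging an image as OOD when $p^{no}_{ij} > 0.5$ for the predicted class, and ATD flagging it as OOD when no $p_{ij}$ exceeds $p_{i,\mathrm{ood}}$.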

Here are the functions I wrote for ATD and CTW from what I understood from your paper; they are a bit raw as this is a work in progress. I used the code in the "handcrafted" folder, which from what I understood is the one to use when dealing with custom prompts rather than the learned ones. Both of them take the logits and logits_no computed this way:

logits = F.normalize(feat, dim=-1, p=2) @ fc_yes.T
logits_no = F.normalize(feat, dim=-1, p=2) @ fc_no.T

as well as a tau parameter, which I set to 1 for now.

import math

def CTW(logits_yes, logits_no, tau):
    yes = logits_yes[0].detach().tolist()
    no = logits_no[0].detach().tolist()
    # Eq. 6: softmax over the "yes" logits with temperature tau
    denominator = 0.0
    for i in range(len(yes)):
        denominator += math.exp(yes[i] / tau)
    pij = []
    for i in range(len(yes)):
        pij.append(math.exp(yes[i] / tau) / denominator)
    # per-class "no" probability: exp(no) / (exp(yes) + exp(no))
    pijno = []
    for i in range(len(no)):
        pijno.append(math.exp(no[i] / tau) / (math.exp(yes[i] / tau) + math.exp(no[i] / tau)))
    # CTW: the prediction is ID if the "no" probability of the top class is below 0.5
    index = pij.index(max(pij))
    bestood = pijno[index]
    return (index, (1 - bestood) > bestood)


def ATD(logits_yes, logits_no, tau):
    yes = logits_yes[0].detach().tolist()
    no = logits_no[0].detach().tolist()
    # per-class "no" probability, as in CTW
    pijno = []
    for i in range(len(no)):
        pijno.append(math.exp(no[i] / tau) / (math.exp(yes[i] / tau) + math.exp(no[i] / tau)))
    # Eq. 6: softmax over the "yes" logits
    denominator = 0.0
    for i in range(len(yes)):
        denominator += math.exp(yes[i] / tau)
    pij = []
    for i in range(len(yes)):
        pij.append(math.exp(yes[i] / tau) / denominator)
    index = pij.index(max(pij))
    # Eq. 8: p_ood = 1 - sum_j (1 - p_no_j) * p_yes_j
    ood = 1.0
    for i, pno in enumerate(pijno):
        ood -= (1 - pno) * pij[i]
    # ATD: ID if any class probability exceeds the OOD probability
    res = 0
    for pyes in pij:
        if pyes > ood:
            res = 1
    return (index, res)
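In case it helps, here is a more compact, vectorized version of the same two checks operating on the torch tensors directly; it should be equivalent to the loop versions above (same tau, and (1, C) logit shapes assumed):

import torch
import torch.nn.functional as F

def ctw_atd(logits_yes, logits_no, tau=1.0):
    # logits_yes, logits_no: (1, C) similarities against the "yes"/"no" text features
    p_yes = F.softmax(logits_yes / tau, dim=-1)             # Eq. 6
    p_no = torch.sigmoid((logits_no - logits_yes) / tau)    # exp(no) / (exp(yes) + exp(no))
    index = p_yes.argmax(dim=-1)

    # CTW: ID if the "no" probability of the predicted class is below 0.5
    ctw_is_id = (p_no.gather(-1, index.unsqueeze(-1)) < 0.5).item()

    # ATD: p_ood = 1 - sum_j (1 - p_no_j) * p_yes_j; ID if any p_yes_j exceeds it
    p_ood = 1.0 - ((1.0 - p_no) * p_yes).sum(dim=-1)
    atd_is_id = (p_yes.max(dim=-1).values > p_ood).item()

    return index.item(), ctw_is_id, atd_is_id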

The return value is 1 if it's ID and 0 otherwise. The model is in eval mode, and I use the process_test function returned by the load_model() function to preprocess the images I load with PIL's Image.open(). So I don't know if I did something wrong or if I "just" need to retrain the model. Thanks for your help!
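For completeness, this is roughly how I load and run each image (load_model and process_test are the ones from the repo as described above; the placeholder path and the encode_image call are just how I wired it up on my side, so they may differ from the intended usage):

from PIL import Image
import torch
import torch.nn.functional as F

# model and process_test come from load_model() as mentioned above; fc_yes / fc_no are the
# "yes"/"no" text weights; "example.jpg" is just a placeholder path
image = Image.open("example.jpg").convert("RGB")
x = process_test(image).unsqueeze(0)   # preprocess and add a batch dimension

with torch.no_grad():
    feat = model.encode_image(x)       # how I extract the image feature; may not be the exact API
    logits = F.normalize(feat, dim=-1, p=2) @ fc_yes.T
    logits_no = F.normalize(feat, dim=-1, p=2) @ fc_no.T
    index, is_id = CTW(logits, logits_no, tau=1.0)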

SiLangWHL commented 9 months ago

Hi,

Sorry for the late reply.

romain-martin commented 9 months ago

Hello, thanks for your answer.

OK, I thought the weights already captured the meaning of the negative keywords. Just to be sure, since I may have misunderstood the difference between hand-crafted prompts (1) and learnable ones (2):

(1) means CLIPN has two text encoders, one of them being used for the negative prompts, which are then combined with the ATD or CTW strategy to extend the original CLIP's zero-shot classification, so it can tell when an image doesn't match any of the given categories.

(2) has the same purpose, except that the negative weights have been learned, so the text encoders are no longer needed; the weights are already embedded in the model?

Correct me if I'm wrong, but in that case, isn't (1) more general and therefore better suited for zero-shot classification?
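In code terms, here is roughly what I have in mind (the function names below are just my guesses, not the actual repo API):

# (1) hand-crafted "no" prompts: the negative text features are produced at inference time
# by the dedicated "no" text encoder (hypothetical names, not the actual repo functions)
fc_yes = encode_text(["a photo of a person with a " + c for c in classnames])
fc_no = encode_text_no(["a photo of a person without a " + c for c in classnames])

# (2) learnable "no" prompts: the negative prompt tokens were optimized during training,
# so fc_no effectively comes from the checkpoint rather than from hand-written text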

Thank you again