Chenxingyu1990 / A-Boundary-Based-Out-of-Distribution-Classifier-for-Generalized-Zero-Shot-Learning

MIT License

Testing Issue #6

Closed: uqzhichen closed this issue 3 years ago

uqzhichen commented 3 years ago

Hi, Chenxing,

I just noticed that your testing code may involve unfair behaviour.

In your testing code below, you first determine whether the test sample comes from a seen or an unseen class and only then apply the threshold. However, whether a sample is seen or unseen is not available at test time. I reckon you need to apply the threshold first and only then check the test sample's true domain.

Please correct me if my understanding is wrong.

Cheers, Zhi

    if kk.item() in unseen_labels.tolist():
        unseen_all += 1
        if dist.max() < thresholds[max_idx]:
            out = self.zsl_classifier(input_k.view(1, -1))
            predlabel = torch.argmax(out, 1)
            pred_label = self.data.unseen_labels[predlabel.cpu().data.item()] - 1
            pred.append(pred_label)
            unseen_count += 1
        else:
            pred.append(1000)
    elif kk.item() in seen_labels.tolist():
        seen_all += 1
        if dist.max() >= thresholds[max_idx]:
            seen_count += 1
            out = self.classifier(z_real[k, :].view(1, -1))
            pred_label = torch.argmax(out, 1).data.item()
            # pred_label = self.data.test_seen_labels[predlabel.cpu().data.item()] - 1
            pred.append(pred_label)
        else:
            pred.append(1000)

Chenxingyu1990 commented 3 years ago

The threshold is searched using the training data, via the function below:

    threshold = model_train_obj.search_thres_by_traindata(test_epoch, dataset=data, n=0.95)
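The thread does not show what `search_thres_by_traindata` does internally. As a purely illustrative assumption, one common recipe is a per-class percentile on training scores: with `n = 0.95`, each seen class's threshold is set so that about 95% of that class's training samples clear it. The helper `percentile_thresholds` below is hypothetical, and `train_scores` is assumed to be the same kind of score that the test loop reads from `dist.max()`.

```python
import math
from collections import defaultdict


def percentile_thresholds(train_scores, train_labels, n=0.95):
    """Hypothetical per-class threshold search on training data (not the repo's code).

    For each seen class c, the threshold is chosen so that at least a fraction n
    of the training samples of class c have a score >= threshold, i.e. they would
    be accepted as "seen". Assumes 0 < n <= 1.
    """
    by_class = defaultdict(list)
    for score, label in zip(train_scores, train_labels):
        by_class[label].append(score)

    thresholds = {}
    for label, scores in by_class.items():
        scores.sort()
        # Number of training samples allowed to fall below the threshold. The tiny
        # epsilon guards against float error, e.g. (1 - 0.9) * 10 evaluating just under 1.
        k = math.floor((1 - n) * len(scores) + 1e-9)
        k = min(k, len(scores) - 1)  # keep the index in range for very small n
        thresholds[label] = scores[k]
    return thresholds


if __name__ == "__main__":
    toy_scores = [i / 20 for i in range(20)]  # 20 toy training scores for one class
    print(percentile_thresholds(toy_scores, [0] * 20, n=0.95))  # {0: 0.05}
```

Under this assumption, a test sample would be treated as seen when its score is at least the threshold of its nearest class, which matches the shape of `dist.max() >= thresholds[max_idx]` in the quoted loop. Only training data is touched, so no test-time domain labels are involved.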

uqzhichen commented 3 years ago

I understand your thresholding operation. But the code I quoted above determines the source of each test sample first. With "if kk.item() in unseen_labels.tolist():", you assume that you know the source of the test sample, but the source is not accessible in the GZSL setting.

Chenxingyu1990 commented 3 years ago

Actually, "if kk.item() in unseen_labels.tolist()" is only used for calculating the accuracy: it counts the total number of unseen samples in the test set. It does not affect the domain classification result, which is determined by "if dist.max() < thresholds[max_idx]:".
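The clarification can be checked with a simplified, framework-free sketch that compares the quoted label-first ordering with the threshold-first ordering suggested in the issue. It assumes precomputed stand-ins for the quoted variables: `scores` for `dist.max()`, `nearest` for `max_idx`, and `seen_pred` / `unseen_pred` for the per-sample outputs of the two classifiers. These names and the toy data are hypothetical, not taken from the repository.

```python
REJECT = 1000  # sentinel used in the quoted code; assumed never to be a real class label


def label_first(scores, nearest, thresholds, labels, seen, unseen, seen_pred, unseen_pred):
    """Quoted ordering: branch on the ground-truth domain, then apply the threshold.

    A sample is routed to the seen domain when its score clears the threshold of
    its nearest seen class (the role of `dist.max() >= thresholds[max_idx]`).
    """
    pred = []
    for i, y in enumerate(labels):
        routed_seen = scores[i] >= thresholds[nearest[i]]
        if y in unseen:
            pred.append(REJECT if routed_seen else unseen_pred[i])
        elif y in seen:
            pred.append(seen_pred[i] if routed_seen else REJECT)
    return pred


def threshold_first(scores, nearest, thresholds, seen_pred, unseen_pred):
    """Label-free ordering: the threshold alone decides which classifier answers."""
    return [
        seen_pred[i] if scores[i] >= thresholds[nearest[i]] else unseen_pred[i]
        for i in range(len(scores))
    ]


def per_domain_accuracy(pred, labels, seen, unseen):
    """Ground truth enters only here: to group samples by domain and count hits."""
    acc = {}
    for name, domain in (("seen", seen), ("unseen", unseen)):
        idx = [i for i, y in enumerate(labels) if y in domain]
        hits = sum(pred[i] == labels[i] for i in idx)
        acc[name] = hits / len(idx) if idx else 0.0
    return acc


if __name__ == "__main__":
    seen, unseen = {0, 1}, {2, 3}            # disjoint label sets, as in GZSL
    labels = [0, 1, 2, 3, 0, 2]              # toy ground truth
    scores = [0.9, 0.2, 0.1, 0.8, 0.7, 0.3]  # toy stand-ins for dist.max()
    nearest = [0, 1, 0, 1, 0, 1]             # toy stand-ins for max_idx
    thresholds = {0: 0.5, 1: 0.5}
    seen_pred = [0, 1, 0, 1, 1, 0]           # stub seen-classifier outputs
    unseen_pred = [2, 3, 2, 3, 2, 3]         # stub unseen-classifier outputs

    a = label_first(scores, nearest, thresholds, labels, seen, unseen, seen_pred, unseen_pred)
    b = threshold_first(scores, nearest, thresholds, seen_pred, unseen_pred)
    print(a)  # [0, 1000, 2, 1000, 1, 3]
    print(b)  # [0, 3, 2, 1, 1, 3]
    same = per_domain_accuracy(a, labels, seen, unseen) == per_domain_accuracy(b, labels, seen, unseen)
    print(same)  # True
```

The two orderings produce different `pred` lists: where a sample is routed to the wrong domain, the quoted code writes the sentinel 1000, while the threshold-first version emits a class from the other domain. Because seen and unseen label sets are disjoint, both outcomes count as errors, so the per-domain accuracies come out the same.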

uqzhichen commented 3 years ago

Yea! I got it! Thank you for your clarification!