Closed uqzhichen closed 3 years ago
The threshold is searched on the training data, using the function below.
threshold = model_train_obj.search_thres_by_traindata(test_epoch, dataset = data, n = 0.95)
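The thread does not show the body of search_thres_by_traindata, so the following is only a sketch of one plausible reading, not the repository's implementation. It assumes that n = 0.95 means "keep 95% of each class's training samples above the threshold", and that each training sample comes with a scalar score (higher means more confidently seen). The function name search_thresholds and its arguments are hypothetical.

```python
import math
from collections import defaultdict


def search_thresholds(scores, labels, n=0.95):
    """Hypothetical per-class threshold search on training data.

    Assumption (not confirmed by the thread): for every class, the
    threshold is the largest value such that at least a fraction `n`
    of that class's training scores are >= the threshold.
    """
    # Group the training scores by their class label.
    by_class = defaultdict(list)
    for score, label in zip(scores, labels):
        by_class[label].append(score)

    thresholds = {}
    for label, vals in by_class.items():
        vals.sort(reverse=True)
        # Number of samples that must stay at or above the threshold.
        keep = max(1, math.ceil(n * len(vals)))
        thresholds[label] = vals[keep - 1]
    return thresholds
```

Only training samples are used here, so no test-time label enters the threshold itself.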
I understand your thresholding operation. But the code I quoted determines the source of each testing sample first. With "if kk.item() in unseen_labels.tolist():", you assume that you know the source of the testing sample, but the source is not accessible in the GZSL setting.
Actually, "if kk.item() in unseen_labels.tolist()" is only used for calculating the accuracy: it counts the total number of unseen samples in the test set. It does not affect the domain classification result, because that result is determined by "if dist.max()<thresholds[max_idx]:".
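To make the distinction concrete, here is a minimal sketch that separates the two roles: the domain prediction uses only the score and the threshold, while the ground-truth label is consulted afterwards, purely for bookkeeping. It is not the repository's code. The names predict_domain and domain_accuracy are hypothetical, and each sample is assumed to arrive with its maximum score and the threshold already selected for its nearest class (the role of thresholds[max_idx] in the quoted code).

```python
SEEN, UNSEEN = "seen", "unseen"


def predict_domain(max_score, threshold):
    """Label-free decision: a score below the threshold means 'unseen'."""
    return UNSEEN if max_score < threshold else SEEN


def domain_accuracy(scores, thresholds, labels, unseen_labels):
    """Return (predictions, unseen accuracy, seen accuracy).

    scores[i] and thresholds[i] belong to test sample i; labels[i] is its
    ground-truth class, used only in the evaluation step.
    """
    unseen_set = set(unseen_labels)

    # Step 1: predict the domain without touching the ground-truth labels.
    predicted = [predict_domain(s, t) for s, t in zip(scores, thresholds)]

    # Step 2: use the labels only to count how often step 1 was right.
    unseen_all = unseen_count = seen_all = seen_count = 0
    for pred, label in zip(predicted, labels):
        if label in unseen_set:
            unseen_all += 1
            unseen_count += pred == UNSEEN
        else:
            seen_all += 1
            seen_count += pred == SEEN

    # max(..., 1) avoids division by zero when a domain has no test samples.
    return (
        predicted,
        unseen_count / max(unseen_all, 1),
        seen_count / max(seen_all, 1),
    )
```

Changing the labels passed to domain_accuracy changes the reported accuracies but never the predictions, which is the property claimed in the reply above.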
Yea! I got it! Thank you for your clarification!
Hi, Chenxing,
I just noticed that your testing code may involve unfair behaviour.
In your testing code below, you first determine if the testing sample is from seen or unseen classes and then perform thresholding. However, being seen or unseen is not accessible during testing. I reckon you may need to judge the threshold first and then check the testing sample's domain source.
Please correct me if my understanding is wrong.
Cheers, Zhi
if kk.item() in unseen_labels.tolist():
    unseen_all += 1
    if dist.max() < thresholds[max_idx]:
        out = self.zsl_classifier(input_k.view(1,-1))
        predlabel = torch.argmax(out,1)
        pred_label = self.data.unseen_labels[predlabel.cpu().data.item()]-1
        pred.append(pred_label)
        unseen_count += 1
    else:
        pred.append(1000)
elif kk.item() in seen_labels.tolist():
    seen_all += 1
    if dist.max() >= thresholds[max_idx]:
        seen_count += 1
        out = self.classifier(z_real[k,:].view(1,-1))
        pred_label = torch.argmax(out,1).data.item()
        pred_label = self.data.test_seen_labels[predlabel.cpu().data.item()]-1