Closed kumam92 closed 3 years ago
Hi,
All classes are relabeled to 0-4 for a 5-way classification task.
To get the real label index in the dataset, please use
`real_label = batch[1]`
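For illustration, here is a tiny self-contained sketch (not the repository's actual code; the sampled class ids and the per-class query count are made up) of how the re-indexed 0-4 labels relate to the real dataset ids returned in `batch[1]`:

```python
import torch

# Toy illustration: for a 5-way task, the sampler draws 5 real classes and the
# loader re-indexes them to 0-4. `batch[1]` keeps the original dataset ids,
# while the loss/accuracy are computed against the re-indexed targets.
real_ids_of_sampled_classes = torch.tensor([17, 3, 42, 8, 25])  # hypothetical real class ids
label = torch.arange(5).repeat_interleave(3)                    # re-indexed targets, e.g. 3 queries per class
real_label = real_ids_of_sampled_classes[label]                 # what batch[1] would contain

print(label)       # tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4])
print(real_label)  # tensor([17, 17, 17,  3,  3,  3, 42, 42, 42,  8,  8,  8, 25, 25, 25])
```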
Thank you for the reply. I had the idea of extracting the actual labels, but I was not sure how to use them for the confusion matrix. The logits coming out of the model are still based on the relabeled indices [0-4], and accuracy is computed against these relabeled values. Is it possible to convert the model output (logits) into the actual labels instead of the relabeled values 0-4? The confusion matrix has to be plotted with the actual labels; otherwise it is not possible to plot.
I would suggest the following approach: build a mapping between the actual labels and the pseudo labels. Once you get the prediction via `prediction = torch.argmax(logits, dim=1)`, you can map it back to the actual label. I believe you need 3-4 additional lines of code for the three-way mapping `real_label` -> `relabeled_label` -> `prediction`. Then you can compute the confusion matrix as usual by comparing `real_label` with the mapped prediction label.
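Roughly, that mapping could look like the following self-contained sketch (all variable names and the toy task data are illustrative, and it assumes `label` and `real_label = batch[1]` are aligned element-wise for the same query samples):

```python
import torch
from sklearn.metrics import confusion_matrix

def map_predictions_to_real(prediction, label, real_label):
    # Build the relabeled-index -> real-class-id mapping from one task's batch.
    relabel_to_real = {}
    for pseudo, real in zip(label.tolist(), real_label.tolist()):
        relabel_to_real[pseudo] = real
    # Translate predictions from the 0..N-1 space back to real class ids.
    return torch.tensor([relabel_to_real[p] for p in prediction.tolist()])

# Toy 5-way task with 2 queries per class and made-up real class ids.
real_label = torch.tensor([11, 11, 4, 4, 29, 29, 7, 7, 16, 16])  # batch[1]
label = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])             # re-indexed targets
logits = torch.randn(10, 5)                                       # stand-in for the model output
prediction = torch.argmax(logits, dim=1)                          # predictions in the 0-4 space

mapped_pred = map_predictions_to_real(prediction, label, real_label)
cm = confusion_matrix(real_label.numpy(), mapped_pred.numpy())
print(cm)
```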
I agree with kumarmanas. In each task, you may construct a mapping between the real label `batch[1]` and the re-indexed label `label`.
Assume there are S classes in total, and each time we sample an N-way task. Then we can use the previous mapping to fill the N-by-N submatrix of the size-S confusion matrix. By sampling many tasks, you can get an estimate of the full confusion matrix.
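As a rough sketch of this accumulation (the class count S, the number of tasks, and the per-task data below are all made up for illustration; in practice each task's labels and predictions come from the test loop, already mapped back to real class ids as in the snippet above):

```python
import numpy as np

S = 20                                   # total number of test classes (illustrative)
full_cm = np.zeros((S, S), dtype=np.int64)

num_tasks = 10000                        # sample many N-way tasks
for _ in range(num_tasks):
    # Stand-ins for one evaluated 5-way task with 2 queries per class.
    sampled_classes = np.random.choice(S, size=5, replace=False)
    true_real = np.repeat(sampled_classes, 2)
    pred_real = sampled_classes[np.random.randint(0, 5, size=10)]
    # Each task only touches the 5 x 5 sub-block of its sampled classes.
    np.add.at(full_cm, (true_real, pred_real), 1)

# After enough tasks, full_cm estimates the confusion structure over all S classes.
print(full_cm.sum(), full_cm.shape)
```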
Thank you, @kumarmanas and @Han-Jia, the mapping approach worked fine for me.
I am going through the code for the testing part, and I want to know which images are incorrectly classified by the model. When I print the data and labels in each batch, I find that the labels are always 0 to 4 (5-way 5-shot classification). Is this correct? Every batch contains a different set of image classes, yet the printed labels are always 0 to 4. So if I try to compute confusion matrix statistics or plot a confusion matrix, the result is wrong, because the labels are always 0 to 4 instead of the actual labels. The part of the code to look into: `with torch.no_grad():`
To plot the confusion matrix, do I have to comment out the one-hot encoding part?