qubvel / segmentation_models.pytorch

Semantic segmentation models with 500+ pretrained convolutional and transformer-based backbones.
https://smp.readthedocs.io/
MIT License

Get score per class #863

Open tabikhm opened 3 months ago

tabikhm commented 3 months ago

When running the following with predictions and targets of shape (batch_size, num_classes, image_height, image_width):

tp, fp, fn, tn = smp.metrics.get_stats(predictions, targets, mode='multiclass', num_classes=9)
iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction=None)
recall = smp.metrics.recall(tp, fp, fn, tn, reduction=None)
precision = smp.metrics.precision(tp, fp, fn, tn, reduction=None)
f1score = smp.metrics.f1_score(tp, fp, fn, tn, reduction=None)

All of the returned tensors (tp, fp, fn, tn, as well as iou_score, recall, precision, and f1score) have size [batch_size, num_classes].

AFAIK, for iou_score this gives the per-class score of each prediction (one row per image). If I average the tensor along dim 0, I should get the mean score per class; please correct me if this is wrong. If that is the case, I am getting a score of 1.0 for every class except the first two, despite there being clear differences between the target and the prediction.
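For concreteness, this is how I reduce it to a per-class score (just averaging over the batch dimension; iou_score here is the [batch_size, num_classes] tensor shown below):

# average over the batch dimension (dim 0) to get one mean score per class
per_class_iou = iou_score.mean(dim=0)   # shape: [num_classes]
print(per_class_iou)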

iou_score:

tensor([[0.9594, 0.5560, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9878, 0.8009, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9873, 0.8099, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        ...,
        [0.9835, 0.7933, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9862, 0.8013, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [0.9672, 0.6682, 1.0000,  ..., 1.0000, 1.0000, 1.0000]])

iou_score for image 0 (iou_score[0]):

tensor([0.9594, 0.5560, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

Plotting the target vs. the prediction for class ID 4:

Target for class ID 4:

import matplotlib.pyplot as plt

target = targets[0][4].to("cpu").numpy()
plt.imshow(target)

[image: target mask for class ID 4]

Prediction for the same class ID 4:

pred = predictions[0][4].to("cpu").numpy()
plt.imshow(pred)

[image: predicted mask for class ID 4]

I think this is why I am getting overly optimistic metrics that do not reflect the actual predictions. When I switch to "multilabel" instead of "multiclass", the results make much more sense. Can someone explain this, please?
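My current guess (please correct me if I am misreading the docs) is that mode='multiclass' expects label-encoded maps of integer class IDs rather than one-hot tensors, so the one-hot (batch_size, num_classes, H, W) tensors I am passing only match what mode='multilabel' expects. Something like the following is what I mean (just a sketch, the exact calls may be off):

# my understanding: 'multiclass' wants integer class-ID maps, so collapse the
# one-hot class dimension with argmax first
preds_labels = predictions.argmax(dim=1)    # (batch_size, H, W), long
target_labels = targets.argmax(dim=1)       # (batch_size, H, W), long
tp, fp, fn, tn = smp.metrics.get_stats(
    preds_labels, target_labels, mode='multiclass', num_classes=9
)

# whereas 'multilabel' works directly on 0/1 tensors of shape
# (batch_size, num_classes, H, W)
tp, fp, fn, tn = smp.metrics.get_stats(
    predictions.long(), targets.long(), mode='multilabel'
)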

github-actions[bot] commented 1 month ago

This issue is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 7 days.