pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
https://pytorch-ignite.ai
BSD 3-Clause "New" or "Revised" License

Label-wise metrics (Accuracy etc.) for multi-label problems #513

Open jphdotam opened 5 years ago

jphdotam commented 5 years ago

Hi,

I've made a multi-label classifier using BCEWithLogitsLoss. In summary, a data sample can belong to any of 3 binary classes, which aren't mutually exclusive, so y_pred and y can look something like [0, 1, 1].

My metrics include Accuracy(output_transform=thresholded_output_transform, is_multilabel=True) and Precision(output_transform=thresholded_output_transform, is_multilabel=True, average=True).
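
For reference, here's a minimal sketch of that setup (the body of thresholded_output_transform below is illustrative, a sigmoid followed by a 0.5 threshold, not necessarily what I use):

import torch
from ignite.metrics import Accuracy, Precision

# Illustrative transform: binarize raw logits with a sigmoid and a 0.5 threshold
def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = (torch.sigmoid(y_pred) > 0.5).long()
    return y_pred, y

metrics = {
    "accuracy": Accuracy(output_transform=thresholded_output_transform, is_multilabel=True),
    "precision": Precision(output_transform=thresholded_output_transform, is_multilabel=True, average=True),
}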

However, I'm interested in having label-specific metrics (i.e. 3 accuracies, etc.). This is important because it lets me see which labels are compromising my overall accuracy the most (a 70% accuracy could be a 30% error in a single label, or a more modest error scattered across all 3 labels).

There is no option to disable averaging for Accuracy() as there is for the others, and setting average=False for Precision() does not do what I expected (it yields a binary result per datum, not per label, so I end up with a tensor of size 500, not 3, if my dataset has n=500).
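
To illustrate with a plain PyTorch sketch (hypothetical tensors, nothing ignite-specific), this is the kind of per-label result I'm after: one accuracy per label rather than one number overall.

import torch

# Hypothetical thresholded predictions and targets: 4 samples, 3 labels
y_pred = torch.tensor([[0, 1, 1],
                       [1, 1, 0],
                       [0, 0, 1],
                       [1, 1, 1]])
y = torch.tensor([[0, 1, 0],
                  [1, 1, 0],
                  [0, 1, 1],
                  [1, 0, 1]])

# Label-wise accuracy: fraction of samples where each label is predicted correctly
labelwise_acc = (y_pred == y).float().mean(dim=0)
print(labelwise_acc)  # tensor([1.0000, 0.5000, 0.7500]), i.e. one accuracy per label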

Is there a way to get label-wise metrics in multi-label problems? Or a plan to introduce it?

P.S. I'd love to get an invite to the Slack workspace if possible. How do I go about doing that?

vfdev-5 commented 5 years ago

@jphdotam thanks for the feedback! You are correct, the multi-label case is always averaged for now for Accuracy, Precision and Recall.

Is there a way to get label-wise metrics in multi-label problems? Or a plan to introduce it?

There is an issue with a similar requirement: https://github.com/pytorch/ignite/issues/467 For the moment we don't have much bandwidth to work on that, but if you can send a PR for it, that would be awesome.

P.S. I'd love to get an invite to the Slack workspace if possible. How do I go about doing that?

You can find a link for that here: https://pytorch.org/resources

jphdotam commented 5 years ago

Many thanks, I've made a pull request here: https://github.com/pytorch/ignite/pull/516

I'm quite new to working on large projects, so apologies if I've gone about this inappropriately.

jphdotam commented 5 years ago

In the meantime, while the core team decides how best to implement this, here is a custom class I've made for the task, which inherits from Accuracy:

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics import Accuracy


class LabelwiseAccuracy(Accuracy):
    def __init__(self, output_transform=lambda x: x):
        self._num_correct = None
        self._num_examples = None
        super(LabelwiseAccuracy, self).__init__(output_transform=output_transform)

    def reset(self):
        self._num_correct = None
        self._num_examples = 0
        super(LabelwiseAccuracy, self).reset()

    def update(self, output):
        y_pred, y = self._check_shape(output)
        self._check_type((y_pred, y))

        # Flatten everything except the class dimension to shape (N, num_classes)
        num_classes = y_pred.size(1)
        last_dim = y_pred.ndimension()
        y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
        y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)

        correct_exact = torch.all(y == y_pred.type_as(y), dim=-1)  # sample-wise exact match
        correct_elementwise = torch.sum(y == y_pred.type_as(y), dim=0)  # per-label correct counts

        if self._num_correct is not None:
            self._num_correct = torch.add(self._num_correct, correct_elementwise)
        else:
            self._num_correct = correct_elementwise
        self._num_examples += correct_exact.shape[0]

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('Accuracy must have at least one example before it can be computed.')
        return self._num_correct.type(torch.float) / self._num_examples
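
And a quick usage sketch for attaching it to an evaluator (model, device, val_loader and the metric name are placeholders; thresholded_output_transform is the same transform as above):

from ignite.engine import create_supervised_evaluator

evaluator = create_supervised_evaluator(
    model,
    metrics={"labelwise_accuracy": LabelwiseAccuracy(output_transform=thresholded_output_transform)},
    device=device,
)
evaluator.run(val_loader)
print(evaluator.state.metrics["labelwise_accuracy"])  # tensor with one accuracy per label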

crypdick commented 4 years ago

For anyone trying to use @jphdotam's code in https://github.com/pytorch/ignite/issues/513#issuecomment-488983281, the line

y_pred, y = self._check_shape(output)

throws an exception because that function now returns nothing. Instead, use

self._check_shape(output)
y_pred, y = output
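
In other words, the top of LabelwiseAccuracy.update() becomes something like this (a sketch against newer ignite versions; the rest of the method is unchanged):

def update(self, output):
    # In newer ignite versions _check_shape() only validates and returns None,
    # so unpack y_pred and y from the output tuple directly
    self._check_shape(output)
    y_pred, y = output
    # ... rest of update() as in the original snippet ...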

However, there's something wrong with it, because I'm getting 'labelwise_accuracy': [0.9070000648498535, 0.8530000448226929, 0.8370000123977661, 0.7450000643730164, 0.8720000386238098, 0.7570000290870667, 0.9860000610351562, 0.9190000295639038, 0.8740000128746033] when 'avg_accuracy' is 0.285.

Edit: never mind, I stepped through the code and it was fine. The bug was on my end. Cheers!