pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
https://pytorch-ignite.ai

Accuracy computation with variable classes #802

Open g-karthik opened 4 years ago

g-karthik commented 4 years ago

❓ Questions/Help/Support

I've been using the Accuracy class (in ignite.metrics) successfully so far in a multi-class setting to compute test set accuracy.

Recently, I've hit a new scenario/dataset where the test set has a different number of classes for each example. You can think of this as a multiple-choice selection task, where each example has a different number of candidates.

Now in this scenario, when I used Accuracy, I got the following error:

    ERROR:ignite.engine.engine.Engine:Current run is terminating due to exception: Input data number of classes has changed from 13 to 81.

I dived into the code and I see that the base class for Accuracy, i.e. _BaseClassification, assumes the number of classes/candidates is fixed across test set examples, which is why I'm getting the above error at run time.

However, for multi-class accuracy, as you can see in the update() method here, we simply compute the number of correct predictions via an argmax along dimension 1 and count matches against y.

This computation of correct should be accurate even if the number of candidates changes across two examples from the test set, right?

So in effect, the computed accuracy would still be correct in such a case, because the underlying correct/total counts are accumulated correctly?
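For illustration, here's a minimal standalone sketch (not the ignite code itself) of the accumulation I have in mind: the per-batch correct count from an argmax along dimension 1 does not depend on the class dimension, so batches with different numbers of candidates can still be pooled.

    import torch

    # Toy batches with different numbers of candidates (C=13 vs C=81).
    y_pred1, y1 = torch.randn(4, 13), torch.randint(0, 13, (4,))
    y_pred2, y2 = torch.randn(4, 81), torch.randint(0, 81, (4,))

    num_correct, num_examples = 0, 0
    for y_pred, y in [(y_pred1, y1), (y_pred2, y2)]:
        indices = torch.argmax(y_pred, dim=1)   # shape (N,), independent of C
        correct = torch.eq(indices, y).view(-1)
        num_correct += torch.sum(correct).item()
        num_examples += correct.shape[0]

    print(num_correct / num_examples)           # pooled accuracy over both batches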

What do you think, @vfdev-5? Is it possible to have a version of Accuracy for multi-class where there is no expectation that the number of classes is the same across all examples in the test set?

vfdev-5 commented 4 years ago

@g-karthik I see your problem. _BaseClassification is also used for Precision and Recall, where the number-of-classes check is important. Let's see what can be done for your case.

vfdev-5 commented 4 years ago

@g-karthik thinking about your problem (correct me if I've misunderstood it), you have the following:

y_pred1 = model(x1)  # y_pred1.shape = (N, C1, ...)
np.max(y_true1) <= C1
y_pred2 = model(x2)  # y_pred2.shape = (N, C2, ...)
np.max(y_true2) <= C2
y_pred3 = model(x3)  # y_pred3.shape = (N, C3, ...)
np.max(y_true3) <= C3

What kind of model is it that lets you vary the number of classes like that? Does it make sense to pad y_predX to the maximum number of classes? Anyway, I'm trying to estimate the impact of adding a flag so that Accuracy alone accepts a variable number of classes.
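By padding I mean something like this (a rough sketch for the 2D logits case, using -inf so the padded slots can never win the argmax; pad_logits and C_max are just illustrative names):

    import torch
    import torch.nn.functional as F

    C_max = 81  # maximum number of candidates across the test set

    def pad_logits(y_pred, c_max=C_max):
        # Pad the class dimension on the right with -inf so padded slots never affect argmax.
        n_pad = c_max - y_pred.shape[1]
        return F.pad(y_pred, (0, n_pad), value=float("-inf"))

    y_pred_small = torch.randn(4, 13)         # batch with only 13 candidates
    y_pred_padded = pad_logits(y_pred_small)  # shape (4, 81), argmax unchanged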

g-karthik commented 4 years ago

@vfdev-5 Thanks a lot for the quick response! Here's an example model, GPT2DoubleHeadsModel from the transformers library by Hugging Face.

For the multiple-choice classification task, in principle one could have a varying number of candidates/choices across examples in the test set.

So x1 might have C1 candidates, x2 might have C2 candidates, x3 might have C3 candidates, and so on, like you showed above. The model would output logits over the candidate set for x1, another set of logits over the candidate set for x2, etc.

And we would know the multiple-choice label for x1, the multiple-choice label for x2, etc. So we should be able to compute the accuracy by computing the match between logits and label for x1, then for x2, and so on, and finally aggregate into the total test set accuracy.

I think padding to the maximum number of classes should work, but it can cause memory issues when the maximum number of classes is very high compared to that of most examples.

Fixing the number of classes might make sense at training time to make life easier, but that may not hold at inference time.

g-karthik commented 4 years ago

@vfdev-5 for now, I've unblocked myself by creating a copy of accuracy.py where I raise a warning when num_classes changes instead of raising a RuntimeError.
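An alternative that avoids copying accuracy.py altogether is a small custom metric that skips the num_classes check entirely. A rough sketch (VariableClassAccuracy is just an illustrative name, not part of ignite):

    import torch
    from ignite.exceptions import NotComputableError
    from ignite.metrics import Metric

    class VariableClassAccuracy(Metric):
        """Multi-class accuracy that tolerates a varying class dimension."""

        def reset(self):
            self._num_correct = 0
            self._num_examples = 0

        def update(self, output):
            y_pred, y = output
            # Same computation as ignite's multiclass branch, with no class-count check.
            indices = torch.argmax(y_pred, dim=1)
            correct = torch.eq(indices, y).view(-1)
            self._num_correct += torch.sum(correct).item()
            self._num_examples += correct.shape[0]

        def compute(self):
            if self._num_examples == 0:
                raise NotComputableError("Accuracy must have at least one example before it can be computed.")
            return self._num_correct / self._num_examples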

g-karthik commented 4 years ago

also @vfdev-5, in the update() method for Accuracy:

        elif self._type == "multiclass":
            indices = torch.argmax(y_pred, dim=1)
            correct = torch.eq(indices, y).view(-1)

this is basically computing recall@1 -- could you enhance the Accuracy class to compute recall@k instead of just assuming k=1? Perhaps k could be an argument for the __init__ method of the class?
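i.e., something along these lines (just a sketch of the generalization; with k=1 it reduces to the snippet above):

    import torch

    k = 3
    y_pred = torch.randn(4, 13)
    y = torch.randint(0, 13, (4,))

    # A prediction counts as correct if the true label appears in the top-k logits.
    topk_indices = torch.topk(y_pred, k, dim=1).indices            # shape (N, k)
    correct = (topk_indices == y.unsqueeze(1)).any(dim=1).view(-1)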

vfdev-5 commented 4 years ago

@g-karthik maybe TopKCategoricalAccuracy can help?
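For example (a quick usage sketch; the evaluator/metric-name wiring is just illustrative):

    from ignite.metrics import TopKCategoricalAccuracy

    # Counts a prediction as correct if the true label is among the top-k logits.
    top5 = TopKCategoricalAccuracy(k=5)
    # top5.attach(evaluator, "top5_accuracy")  # attach to your evaluator engine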