g-karthik opened this issue 4 years ago
@g-karthik I see your problem. `_BaseClassification` is also used for `Precision` and `Recall`, where checking the number of classes is important. Let's see what can be done for your case.
@g-karthik thinking about your problem, correct me if I understand it wrongly, you have the following:

```python
y_pred1 = model(x1)  # y_pred1.shape = (N, C1, ...)
np.max(y_true1) <= C1
y_pred2 = model(x2)  # y_pred2.shape = (N, C2, ...)
np.max(y_true2) <= C2
y_pred3 = model(x3)  # y_pred3.shape = (N, C3, ...)
np.max(y_true3) <= C3
```

What kind of model is it that the number of classes can vary like that? Does it make sense to pad `y_predX` to the maximum number of classes?
Anyway, I'm trying to estimate the impact of adding a flag, for accuracy only, to accept a variable number of classes.
@vfdev-5 Thanks a lot for the quick response! Here's an example model: `GPT2DoubleHeadsModel` from the `transformers` library by Hugging Face.
For the multiple-choice classification task, in principle one could have a varying number of candidates/choices across examples in the test set. So `x1` might have C1 candidates, `x2` might have C2 candidates, `x3` might have C3 candidates, and so on, like you showed above. The model would output logits over the candidate set for `x1`, another set of logits over the candidate set for `x2`, etc.
And we would know the multiple-choice label for `x1`, the multiple-choice label for `x2`, etc. So we should be able to compute the accuracy by computing the match between logits and labels for `x1`, then for `x2`, and so on. Finally, we would have the total test set accuracy.
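In other words, something like this per-example computation is what I have in mind (a minimal sketch with made-up function and variable names, not an ignite API):

```python
import torch

def multiple_choice_accuracy(logits_list, labels):
    # logits_list: one 1-D tensor of candidate scores per example (sizes may differ)
    # labels: index of the correct candidate for each example
    correct = 0
    for logits, label in zip(logits_list, labels):
        # argmax over this example's own candidate set, whatever its size
        if torch.argmax(logits).item() == label:
            correct += 1
    return correct / len(labels)

# e.g. 3 candidates for the first example, 5 for the second
logits_list = [torch.tensor([0.1, 2.3, -0.5]), torch.tensor([1.0, 0.2, 0.3, 4.1, -1.0])]
labels = [1, 3]
print(multiple_choice_accuracy(logits_list, labels))  # 1.0
```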
I think padding to the maximum number of classes should work, but that can cause memory issues when the maximum number of classes is very high compared to most of the examples. Fixing the number of classes might make sense at training time to make life easier, but that might not be the case at inference time.
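For concreteness, the padding approach could look something like this (a minimal sketch, with the padding value chosen so `argmax` never picks a padded slot):

```python
import torch

def pad_logits(logits_list, max_num_classes, pad_value=-1e9):
    # pad every example's logits to max_num_classes with a very negative value,
    # so argmax can never select a padded position; memory grows with max_num_classes
    padded = torch.full((len(logits_list), max_num_classes), pad_value)
    for i, logits in enumerate(logits_list):
        padded[i, : logits.shape[0]] = logits
    return padded

batch = pad_logits([torch.tensor([0.1, 2.3]), torch.tensor([1.0, 0.2, 4.1])], max_num_classes=3)
print(batch.shape)             # torch.Size([2, 3])
print(torch.argmax(batch, 1))  # tensor([1, 2])
```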
@vfdev-5 for now, I've unblocked myself by creating a copy of `accuracy.py` where I raise a warning when `num_classes` changes, instead of raising a `RuntimeError`.
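The gist of the change is just this (a standalone sketch of the check, not the actual diff in my copy of `accuracy.py`; the helper name and arguments here are stand-ins):

```python
import warnings

def check_num_classes(previous_num_classes, num_classes):
    # warn-instead-of-raise version of the check that currently throws
    # "Input data number of classes has changed from X to Y"
    if previous_num_classes is not None and previous_num_classes != num_classes:
        warnings.warn(
            f"Input data number of classes has changed from {previous_num_classes} to {num_classes}"
        )
    return num_classes
```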
Also @vfdev-5, in the `update()` method for `Accuracy`:

```python
elif self._type == "multiclass":
    indices = torch.argmax(y_pred, dim=1)
    correct = torch.eq(indices, y).view(-1)
```

this is basically computing recall@1 -- could you enhance the `Accuracy` class to compute recall@k instead of just assuming k=1? Perhaps `k` could be an argument to the `__init__` method of the class?
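Something along these lines is what I mean (a rough sketch, not the actual ignite implementation; the `k` parameter is the hypothetical constructor argument):

```python
import torch

def topk_correct(y_pred, y, k=1):
    # indices of the k highest-scoring classes for each sample
    topk_indices = torch.topk(y_pred, k, dim=1).indices
    # a sample counts as correct if the true label appears among its top-k predictions
    return torch.eq(topk_indices, y.unsqueeze(1)).any(dim=1)

y_pred = torch.tensor([[0.1, 0.5, 0.4], [0.7, 0.2, 0.1]])
y = torch.tensor([2, 0])
print(topk_correct(y_pred, y, k=1))  # tensor([False,  True])
print(topk_correct(y_pred, y, k=2))  # tensor([ True,  True])
```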
@g-karthik maybe `TopKCategoricalAccuracy` can help?
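Roughly (a minimal usage sketch; the example tensors are made up, and it still expects a fixed number of classes per batch):

```python
import torch
from ignite.metrics import TopKCategoricalAccuracy

# counts a prediction as correct if the true label is among the k highest scores
top2 = TopKCategoricalAccuracy(k=2)
top2.update((torch.tensor([[0.1, 0.5, 0.4], [0.7, 0.2, 0.1]]), torch.tensor([2, 0])))
print(top2.compute())  # 1.0
```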
❓ Questions/Help/Support
I've been using the `Accuracy` class (in `ignite.metrics`) successfully so far in a multi-class setting to compute test set accuracy. Recently, I've hit a new scenario/dataset where the test set has a different number of classes for each example. You can think of this as a multiple-choice selection task, where each example has a different number of candidates.
Now in this scenario, when I used `Accuracy`, I got the below error:

```
ERROR:ignite.engine.engine.Engine:Current run is terminating due to exception: Input data number of classes has changed from 13 to 81.
```
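Roughly, the failure can be reproduced like this (the shapes and class counts below are made up for illustration):

```python
import torch
from ignite.metrics import Accuracy

acc = Accuracy()
acc.update((torch.rand(4, 13), torch.randint(0, 13, (4,))))
acc.update((torch.rand(4, 81), torch.randint(0, 81, (4,))))  # raises the error above
```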
I dived into the code, and I see that `_BaseClassification`, the base class for `Accuracy`, assumes that the number of classes/candidates is fixed across test set examples. This is why I'm getting the above error at run-time. However, for multi-class accuracy, as you can see in the `update()` method here, we simply compute the number of correct predictions via `argmax` along dimension 1 and measure matches against `y`.
This computation of `correct` should be accurate even if the number of candidates changes across two examples from the test set, right? So in effect, the computed accuracy would be correct even in such a case, because the underlying variables are computed correctly?
What do you think, @vfdev-5? Is it possible to have a version of `Accuracy` for multi-class where there is no expectation that the number of classes is the same across all examples in the test set?