pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
https://pytorch-ignite.ai
BSD 3-Clause "New" or "Revised" License

Possible improvements for Accuracy #1089

Open Yura52 opened 4 years ago

Yura52 commented 4 years ago

The feature request is described in full detail here; below is a quick recap.

There are two inconveniences I experience with the current interface of Accuracy.

1. Inconsistent input format for binary classification and multiclass problems

In the first case, Accuracy expects labels as input, whilst in the second case it expects probabilities/logits. I am a somewhat experienced Ignite user and I still get confused by this behavior.
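
For concreteness, here is roughly how the two cases look today with toy tensors (shapes as in the docs):

import torch
from ignite.metrics import Accuracy

# Binary case: y_pred is expected to already contain 0/1 labels, shape (N, ...)
acc = Accuracy()
acc.update((torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 0])))
print(acc.compute())  # 0.75

# Multiclass case: y_pred is expected to be logits/probabilities, shape (N, C)
acc = Accuracy()
acc.update((torch.tensor([[0.2, 0.8], [0.9, 0.1]]), torch.tensor([1, 1])))
print(acc.compute())  # 0.5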

2. No shortcuts for saying "I want to pass logits/probabilities as input"

In practice, I have never used Accuracy in the following manner for binary classification:

accuracy = Accuracy()

Instead, I always do one of the following:

accuracy = Accuracy(output_transform=lambda output: (torch.round(torch.sigmoid(output[0])), output[1]))
# or
accuracy = Accuracy(output_transform=lambda output: (torch.round(output[0]), output[1]))

Suggested solution for both problems: let the user explicitly say in which form the input will be passed:

import enum
class Accuracy(...):
    class Mode(enum.Enum):
        LABELS = enum.auto()
        PROBABILITIES = enum.auto()
        LOGITS = enum.auto()

    def __init__(self, mode=Mode.LABELS, ...):
        ...

The suggested interface can also be extended to support custom thresholds by adding a __call__ method to the Mode class.
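
For example, a rough sketch of a callable Mode (the exact threshold handling here is only illustrative):

import enum
import torch

class Mode(enum.Enum):
    LABELS = enum.auto()
    PROBABILITIES = enum.auto()
    LOGITS = enum.auto()

    def __call__(self, y_pred: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        # Convert raw model outputs into hard 0/1 predictions.
        if self is Mode.LABELS:
            return y_pred
        if self is Mode.LOGITS:
            y_pred = torch.sigmoid(y_pred)
        return (y_pred >= threshold).long()

Accuracy could then apply something like self._mode(y_pred, threshold) internally (names here are hypothetical).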

sdesrozis commented 4 years ago

@WeirdKeksButtonK I really appreciate this API! Thank you very much 👍🏻

vcarpani commented 3 years ago

Hello, I believe I could be assigned to this issue, since I have a PR for it.

vfdev-5 commented 3 years ago

@vcarpani sure! On GitHub we cannot assign arbitrary users to an issue, only project members or people who have participated in the conversation here.

sallycaoyu commented 1 year ago

Hi everyone, I would like to try working on this issue.

vfdev-5 commented 1 year ago

Sure @sallycaoyu, please also check all related PRs and mentions.

sallycaoyu commented 1 year ago

For now, I am trying to finish implementing a binary_mode for the binary and multilabel types that transforms probabilities and logits into 0s and 1s, as this PR has done. If that works well, I can then consider how to add more flexibility to the multiclass case, as issue #822 suggests.

Does that sound like a good plan? Or would you like Ignite to have a mode similar to what this issue suggests, i.e., mode being one of [binary, multiclass, multilabel] instead of one of [unchanged, probabilities, logits]? The former would require more modifications to what we have right now, such as removing is_multilabel and replacing it with mode for Accuracy, Precision, Recall, and ClassificationReport, since multilabel would become one option of mode.

vfdev-5 commented 1 year ago

@sallycaoyu thanks for the update! I think we can continue with mode as [unchanged, probabilities, logits, labels?]. Could you please sketch the new API usage with code snippets, emphasizing "before" and "after"? For example:

### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits 

### after
acc = Accuracy(mode=Accuracy.LOGITS)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C), (N, )

acc = Accuracy(mode=Accuracy.LABELS)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass labels  : (N, ), (N, )

etc

sallycaoyu commented 1 year ago

Sure! Suppose we have:

class Accuracy(...):
    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        is_multilabel: bool = False,
        device: Union[str, torch.device] = torch.device("cpu"),
        mode: str = 'unchanged',              # new parameter
        threshold: Union[float, int] = 0.5,   # new parameter
    ):
        ...

Then, for binary and multilabel data:

### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (0s and 1s) : (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (0s and 1s) : (N, C, ...), (N, C, ...)

### after
# LOGITS MODE
acc = Accuracy(mode='logits', threshold=3.25)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary logits (float in [-inf, inf]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='logits', threshold=3.25, is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel logits (float in [-inf, inf]): (N, C, ...), (N, C, ...) 

# in this case, Accuracy will transform any value < 3.25 to 0 and any value >= 3.25 to 1
# if no threshold is passed, Accuracy will apply a sigmoid to the logits and then transform any value < 0.5 to 0 and any value >= 0.5 to 1

# PROBABILITIES MODE
acc = Accuracy(mode='probabilities', threshold=0.6)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary probabilities (float in [0, 1]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='probabilities', threshold=0.6, is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel probabilities (float in [0, 1]): (N, C, ...), (N, C, ...)

# in this case, Accuracy will transform any value < 0.6 to 0 and any value >= 0.6 to 1
# if no threshold is passed, Accuracy will transform any value < 0.5 to 0 and any value >= 0.5 to 1

# LABELS MODE
acc = Accuracy(mode='labels', threshold=5)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (int in [0, inf]): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='labels', threshold=5, is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (int in [0, inf]): (N, C, ...), (N, C, ...)

# in this case, Accuracy will transform any value < 5 to 0 and any value >= 5 to 1
# a threshold must be specified for labels mode

# UNCHANGED MODE
acc = Accuracy(mode='unchanged')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as binary labels (0s and 1s): (N, ...), (N, ...), or (N, 1, ...), (N, ...)

acc = Accuracy(mode='unchanged', is_multilabel=True)
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multilabel labels (0s and 1s): (N, C, ...), (N, C, ...)

# will work like before: raise an error when any value is not 0 or 1
# a threshold should not be specified for unchanged mode
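
To make the intended binary/multilabel behavior concrete, the binarization step could look roughly like this (just a sketch, names and defaults are not final):

from typing import Optional
import torch

def _binarize(y_pred: torch.Tensor, mode: str, threshold: Optional[float] = None) -> torch.Tensor:
    # Sketch only: turn scores into hard 0/1 predictions for the
    # 'logits', 'probabilities' and 'labels' modes on binary/multilabel data.
    if mode == "logits" and threshold is None:
        # no explicit threshold: squash logits into [0, 1] first, then cut at 0.5
        y_pred = torch.sigmoid(y_pred)
        threshold = 0.5
    if threshold is None:
        # default for 'probabilities'; 'labels' mode would require an explicit threshold
        threshold = 0.5
    return (y_pred >= threshold).long()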

For multiclass data:

### before
acc = Accuracy()
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...)

### after: should not apply threshold to multiclass data
# LABELS MODE
acc = Accuracy(mode='labels')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass labels : (N, ...), (N, ...)
# conflict with _check_type(), since we use y.ndimension() + 1 == y_pred.ndimension() to check for multiclass data

# For now, the following multiclass modes will work like before (argmax):
# PROBABILITIES MODE
acc = Accuracy(mode='probabilities')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass probabilities : (N, C, ...), (N, ...)

# LOGITS MODE
acc = Accuracy(mode='logits')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...) 

# UNCHANGED MODE
acc = Accuracy(mode='unchanged')
acc.attach(evaluator, "acc")
# evaluator outputs y_pred, y as multiclass logits : (N, C, ...), (N, ...)
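
For reference, the conflict mentioned for the multiclass labels mode comes from the shape-based type inference; a simplified illustration (not the actual _check_type code):

import torch

def _infer_type(y_pred: torch.Tensor, y: torch.Tensor) -> str:
    # Simplified stand-in for the current shape-based check.
    if y.ndimension() + 1 == y_pred.ndimension():
        return "multiclass"            # y_pred: (N, C, ...), y: (N, ...)
    if y.ndimension() == y_pred.ndimension():
        return "binary/multilabel"     # both (N, ...) or both (N, C, ...)
    raise ValueError("unsupported shapes")

# With mode='labels', y_pred and y would both be (N, ...), so multiclass labels
# would be indistinguishable from the binary case.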

vfdev-5 commented 1 year ago

Thanks a lot for the snippet @sallycaoyu !

I have a few thoughts about that:

What do you think?

sallycaoyu commented 1 year ago

@vfdev-5 Thank you very much for the comments!

I agree that we can drop the unchanged mode, and I also agree that output_transform gives users more flexibility than threshold, so threshold is not really necessary. Then by default, for: