Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

F1 score seems not to match harmonic mean of precision and recall scores #256

Closed jacanchaplais closed 3 years ago

jacanchaplais commented 3 years ago

🐛 Bug

Hi there. I am using the modular forms of BinnedPrecisionRecallCurve and F1 to calculate precision, recall, and F1. I noticed that F1 went down while both precision and recall went up, so I manually checked whether the reported F1 was consistent with the harmonic mean of precision and recall, and it seems to be mostly too high.

TensorBoard logged precision, recall, and F1, averaged over the validation set per epoch.

e.g. for the last data points on my precision (0.1296), recall (0.77197) and F1 (0.3508) graphs:

import typing

Vector = typing.List[float]

def harmonic_mean(in_vals: Vector) -> float:
    num_vals = len(in_vals)
    recip_sum = 0.0
    for val in in_vals:
        recip_sum += val ** -1
    return num_vals / recip_sum

harmonic_mean([0.1296, 0.71197])
>>> 0.21928374823247024

which is significantly smaller than the value reported by the metric.
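
As a sanity check on the relation itself: for a single set of predictions, F1 should equal the harmonic mean of precision and recall exactly. A quick sketch confirming this (using sklearn here rather than torchmetrics, with made-up random labels):

import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

rng = np.random.default_rng(0)
target = rng.integers(0, 2, size=1000)
preds = rng.integers(0, 2, size=1000)

p = precision_score(target, preds)
r = recall_score(target, preds)

# for a single confusion matrix, F1 == harmonic mean of precision and recall
assert np.isclose(f1_score(target, preds), harmonic_mean([p, r]))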

Am I missing something? I am running a tuning algorithm to optimise the F1 score, so it is crucial that it is correct.

Code sample

You can see my implementation of the metrics here (from line 70 to EOF); it is relatively simple and (I think) follows the recommendations in the documentation.

https://github.com/jacanchaplais/cluster_gnn/blob/f-gnn/src/cluster_gnn/models/gnn.py#L70

        self.train_ACC = torchmetrics.Accuracy(threshold=infer_thresh)
        self.train_F1 = torchmetrics.F1(
                num_classes=1, threshold=infer_thresh)
        self.val_ACC = torchmetrics.Accuracy(threshold=infer_thresh)
        self.val_F1 = torchmetrics.F1(
                num_classes=1, threshold=infer_thresh)
        self.val_PR = torchmetrics.BinnedPrecisionRecallCurve(
                num_classes=1, num_thresholds=5)

    def forward(self, data, sigmoid=True):
        node_attrs, edge_attrs = data.x, data.edge_attr
        edge_attrs, node_attrs = self.encode(node_attrs, data.edge_index,
                                             edge_attrs)
        edge_attrs, node_attrs = self.message(node_attrs, data.edge_index,
                                              edge_attrs)
        pred = self.classify(edge_attrs)
        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
                self.parameters(),
                lr=self.lr,
                weight_decay=self.decay
            )
        return optimizer

    def _train_av_loss(self, outputs):
        return torch.stack([x['loss'] for x in outputs]).mean()

    def _val_av_loss(self, losses):
        return torch.stack(losses).mean()

    def training_step(self, batch, batch_idx):
        edge_pred = self(batch, sigmoid=False)
        loss = self.criterion(edge_pred, batch.y.view(-1, 1))
        return {'loss': loss,
                'preds': torch.sigmoid(edge_pred),
                'target': batch.y.view(-1, 1).int()}

    def training_step_end(self, outputs):
        self.train_ACC(outputs['preds'], outputs['target'])
        self.train_F1(outputs['preds'], outputs['target'])
        self.log('loss/train_step', outputs['loss'], on_step=True)
        self.log('acc/train_step', self.train_ACC, on_step=True)
        self.log('f1/train_step', self.train_F1, on_step=True)
        return outputs['loss']

    def training_epoch_end(self, outputs):
        self.log('loss/train_epoch', self._train_av_loss(outputs))
        self.log('acc/train_epoch', self.train_ACC.compute())
        self.log('f1/train_epoch', self.train_F1.compute())

    def validation_step(self, batch, batch_idx):
        edge_pred = self(batch, sigmoid=False)
        loss = self.criterion(edge_pred, batch.y.view(-1, 1))
        return {'loss': loss,
                'preds': torch.sigmoid(edge_pred),
                'target': batch.y.view(-1, 1).int()}

    def validation_step_end(self, outputs):
        self.val_ACC(outputs['preds'], outputs['target'])
        self.val_F1(outputs['preds'], outputs['target'])
        self.val_PR(outputs['preds'], outputs['target'])
        self.log('loss/val_step', outputs['loss'], on_step=True)
        self.log('acc/val_step', self.val_ACC, on_step=True)
        return outputs['loss']

    def validation_epoch_end(self, outputs):
        self.log('loss/val_epoch', self._val_av_loss(outputs))
        self.log('acc/val_epoch', self.val_ACC.compute())
        self.log('f1/val_epoch', self.val_F1.compute())
        prec, recall, thresh = self.val_PR.compute()
        for i, t in enumerate(thresh):
            self.log(f'prec_val_epoch/thresh_{t:.3f}', prec[i])
            self.log(f'recall_val_epoch/thresh_{t:.3f}', recall[i])

Environment

Installed via conda with following environment.yml

name: ptg

channels:
  - pytorch
  - conda-forge
  - defaults

dependencies:
  - cudatoolkit=10.2 # version for cluster compatibility
  - pip
  - python=3.8 # version for PyTorch Geometric compatibility
  - tqdm # provides progress bars to Python loops
  - jupyter
  - pytorch=1.7 # version for cluster compatibility
  - torchaudio # unnecessary for production (but useful for examples)
  - torchvision # unnecessary for production (but useful for examples)
  - networkx # required by PyTorch Geometric
  - torchmetrics # modular ML metrics for PyTorch
  - pytorch-lightning # parallelisation
  - ray-tune # tuning hyperparams
  - hyperopt # tuning hyperparams
  - tensorboard # visualising metrics
  - scipy
  - numpy
  - numba # JIT compilation and other performance boosts
  - pandas # DataFrames
  - h5py # interface to HDF5 data files
  - python-louvain
  - scikit-learn
  - requests
  - rdflib
  - googledrivedownloader # required by PyTorch Geometric
  - ase
  - jinja2
  - zenodo_get # downloading public datasets

Automatically determined version numbers:

I am also using PyTorch Geometric installed via pip, although this shouldn't be relevant as the metrics are never exposed to it.

github-actions[bot] commented 3 years ago

Hi! Thanks for your contribution! Great first issue!

jacanchaplais commented 3 years ago

Actually, thinking about it, there may well be a difference between the average of the harmonic means over a validation set, and the harmonic mean of the averages over that set. I think this may explain it, apologies. I will do some maths and close the issue if it turns out I am being a moron.

jacanchaplais commented 3 years ago

I am a moron confirmed. Apologies for wasting your time.


import numpy as np

prec_recall_set = np.random.rand(100, 2)
all_f1s = np.array([harmonic_mean(prec_recall) for prec_recall in prec_recall_set])
mean_f1 = all_f1s.mean()

mean_prec_recall = np.mean(prec_recall_set, axis=0)
f1_of_means = harmonic_mean(mean_prec_recall)

print(mean_f1)
>>> 0.4254433011584456

print(f1_of_means)
>>> 0.5107380180397513
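
For what it's worth, the gap always has this sign: the harmonic mean is a concave function, so by Jensen's inequality the average of per-batch F1 scores can never exceed the F1 computed from the averaged precision and recall:

$$\frac{1}{N}\sum_{i=1}^{N} \mathrm{HM}(p_i, r_i) \;\le\; \mathrm{HM}\left(\frac{1}{N}\sum_{i} p_i,\ \frac{1}{N}\sum_{i} r_i\right)$$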

Borda commented 3 years ago

@jacanchaplais no need, thx for your concern about the implementation's correctness :rabbit:

teichert commented 3 years ago

I think this point is actually well worth clarifying in the documentation, and it might even be worth adding a macro-pre (or, alternatively, pre-harmonic-macro or early-macro) averaging method that corresponds to the OP's original interpretation.

From this note and my own experience/investigations, both interpretations of macro-averaged F1 exist in the literature. And although it seems that torchmetrics, sklearn, and tensorflow are all consistent in their interpretation, having both available in torchmetrics could ease comparison with prior work and highlight the distinction so that whichever is chosen is properly described.

As to the source of the OP's original interpretation (and my own!): the book commonly credited with introducing the F-measure discusses micro- and macro-averaging schemes in the context of precision-recall curves before introducing the equivalent of F1 (see this article for a derivation connecting the two; the original definition in that article is mostly reprinted verbatim in the book but does not mention either type of averaging). After describing micro- and macro-averaged precision and recall, the book then defines the "effectiveness" measure, F, in terms of precision and recall (without reference to any type of averaging), so that may be the source of this interpretation. (It would be interesting to ask Dr. van Rijsbergen what his recommendation would have been.)
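
To make the distinction concrete, here is a minimal numpy sketch of the two averaging orders (the per-class precision/recall values and the names macro_f1_late / macro_f1_early are made up for illustration, not torchmetrics API):

import numpy as np

# hypothetical per-class precision and recall for a 2-class problem
prec = np.array([0.2, 0.8])
rec = np.array([0.8, 0.2])

# "late" averaging (the macro average in torchmetrics/sklearn): per-class F1 first, then arithmetic mean
macro_f1_late = (2 * prec * rec / (prec + rec)).mean()

# "early" (macro-pre) averaging: mean precision/recall first, harmonic mean last
macro_f1_early = 2 * prec.mean() * rec.mean() / (prec.mean() + rec.mean())

print(macro_f1_late, macro_f1_early)  # ~0.32 vs 0.5, very different numbers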