Lightning-AI / torchmetrics

Torchmetrics - Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

RMSE value is not equal to the square root of MSE #879

martinwhl closed this issue 2 years ago

martinwhl commented 2 years ago

🐛 Bug

I'm currently working on a multivariate time series forecasting problem, and RMSE is one of the evaluation metrics. After evaluating the model on the test set, I noticed that the RMSE value was a little lower than expected, and that the square of the RMSE value is not equal to the MSE value.

To Reproduce

Here's my simple implementation of RMSE:

import torch
import torchmetrics

class RootMeanSquaredError(torchmetrics.MeanSquaredError):
    def compute(self):
        return torch.sqrt(super().compute())
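
For completeness, a quick sanity check of the subclass on its own (the input values are made up for illustration):

rmse = RootMeanSquaredError()
rmse.update(torch.tensor([1.0, 2.0]), torch.tensor([1.5, 2.5]))
print(rmse.compute())  # tensor(0.5000), the square root of MSE 0.25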

Code sample

Here's the link to the repo that reproduces the problem, with data attached: rmse_test.

Expected behavior

RMSE should be approximately equal to the square root of MSE.

Environment

github-actions[bot] commented 2 years ago

Hi! Thanks for your contribution, great first issue!

SkafteNicki commented 2 years ago

I am investigating this issue, but I am confident that this is not an issue with torchmetrics but with Lightning. If, after running the script, I try to compute:

res = model.metrics.compute() # {'MeanSquaredError': tensor(841.2437), 'RootMeanSquaredError': tensor(29.0042)}
res["MeanSquaredError"] == res["RootMeanSquaredError"] ** 2 # tensor(True)

I get the correct result, meaning that the internal metric computations are right. @martinwhl Also note that the MeanSquaredError metric has a squared argument that you can set to False if you want the RMSE metric.
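
For reference, a minimal sketch of that built-in option (the preds/target values are made up for illustration):

import torch
import torchmetrics

# squared=False makes MeanSquaredError return the root of the accumulated MSE
rmse = torchmetrics.MeanSquaredError(squared=False)
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
target = torch.tensor([3.0, -0.5, 2.0, 7.0])
print(rmse(preds, target))  # tensor(0.6124)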

SkafteNicki commented 2 years ago

Another observation: if the value is instead logged as:

    def test_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        y *= self.max_value
        pred *= self.max_value
        self.metrics.update(pred, y)

    def test_epoch_end(self, outputs) -> None:
        self.log_dict(self.metrics.compute())

then it also works, since compute() is only called once, after all batches have been accumulated.

SkafteNicki commented 2 years ago

Okay, final update. The issue is this line:

self.log_dict(self.metrics(pred, y), batch_size=x.size(0))

The problem is that self.metrics(pred, y) returns a dict of the metric values on the current batch, so the input to self.log_dict is a dict of plain tensors. When Lightning detects that the input to self.log is a tensor, it automatically averages the value over all batches, which gives a slightly biased value for RMSE because the square root does not commute with averaging: the mean of per-batch square roots is not the square root of the mean (sqrt((a + b) / 2) != (sqrt(a) + sqrt(b)) / 2).
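
To make the bias concrete, here is a minimal numeric sketch (the per-batch MSE values are made up):

import torch

# Suppose two equally sized batches produce these per-batch MSE values:
batch_mse = torch.tensor([100.0, 900.0])

# What Lightning does when logging per-batch RMSE tensors: average the roots
mean_of_rmse = batch_mse.sqrt().mean()  # (10 + 30) / 2 = 20.0

# What RMSE should be: accumulate the squared errors first, then take the root
rmse_of_mean = batch_mse.mean().sqrt()  # sqrt(500) ~= 22.36

print(mean_of_rmse, rmse_of_mean)  # tensor(20.) tensor(22.3607)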

The solution is to log the metric object instead, which takes care of the accumulation and makes sure that everything is done correctly. Like this:

self.metrics(pred, y)
self.log_dict(self.metrics)

which is also the approach described in the documentation: https://torchmetrics.readthedocs.io/en/latest/pages/lightning.html
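
Applied to the test_step from the reproduction, the fix looks roughly like this (a sketch, assuming the same attribute names as in the repro):

    def test_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        y *= self.max_value
        pred *= self.max_value
        # Update the metric states, then log the metric object itself;
        # Lightning will call compute() once at epoch end.
        self.metrics(pred, y)
        self.log_dict(self.metrics)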

Sorry for the confusion.

martinwhl commented 2 years ago

Thanks a lot for your help!