pytorch / torcheval

A library that contains a rich collection of performant PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training and tools for PyTorch model evaluations.
https://pytorch.org/torcheval

Fix warning in aggregation.mean #187

Status: Closed (bobakfb closed this 8 months ago)

bobakfb commented 8 months ago

Summary: This diff fixes an incorrect warning emitted by mean.compute() when the mean is exactly 0.

Instead of checking whether the weighted sum of elements is 0, we now check whether the total sum of weights is 0. A mean of exactly 0 is therefore returned without a spurious warning, while a warning is still raised when the computation would divide by zero.

We also update the warning message to state that no weight has been accumulated, since it is possible to call this function with only zero weights.
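To illustrate the difference between the two checks, here is a minimal pure-Python sketch of a running weighted mean. The class name `WeightedMean`, its attributes, and the choice to return `0.0` when no weight has been accumulated are all hypothetical assumptions for illustration; this is not torcheval's actual `Mean` implementation or API.

```python
import warnings


class WeightedMean:
    """Hypothetical running weighted mean (illustration only, not torcheval's class)."""

    def __init__(self) -> None:
        self.weighted_sum = 0.0  # running sum of value * weight
        self.weights = 0.0       # running sum of weights

    def update(self, value: float, weight: float = 1.0) -> "WeightedMean":
        self.weighted_sum += value * weight
        self.weights += weight
        return self

    def compute(self) -> float:
        # Guard on the accumulated weight rather than on the weighted sum:
        # a weighted sum of 0 (e.g. values 1 and -1) is a legitimate mean of 0,
        # whereas a total weight of 0 would mean dividing by zero.
        if self.weights == 0:
            # Assumed fallback value for this sketch; the warning text is hypothetical.
            warnings.warn("No weight has been accumulated; returning 0.0.", RuntimeWarning)
            return 0.0
        return self.weighted_sum / self.weights
```

With this check, `WeightedMean().update(1.0).update(-1.0).compute()` returns 0.0 silently, while `WeightedMean().update(5.0, weight=0.0).compute()` warns, because only zero weights were accumulated. A check on `weighted_sum == 0` would get both cases wrong: it would warn on the first and attempt a division by zero on the second.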

Addresses: https://github.com/pytorch/torcheval/issues/185

Reviewed By: JKSenthil

Differential Revision: D50806243

facebook-github-bot commented 8 months ago

This pull request was exported from Phabricator. Differential Revision: D50806243