Lightning-AI / torchmetrics

Torchmetrics - Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

Support metrics reduction outside of DDP #2063

Open ytang137 opened 1 year ago

ytang137 commented 1 year ago

🚀 Feature

Support metrics reduction outside of DDP.

Motivation

When used within DDP, torchmetrics objects support automatic syncing and reduction across ranks. However, there doesn't seem to be support for reduction outside DDP. This will be a good feature to have because it allows using torchmetrics for distributed evaluation using frameworks other than DDP.

Pitch

Let's say we are computing metrics on a large dataset. Each worker receives a shard of the dataset and computes the metric for its shard, and we collect the metric objects from all workers: metrics = [metric_0, metric_1, metric_2, metric_3, ...]. To compute the final metric across the entire dataset, we need a mechanism to reduce these metric objects. Looking at torchmetrics/src/torchmetrics/metric.py, I see one potential solution:

metric_reduced = MetricType()  # Same type as the metrics in metrics list
for metric in metrics:
    metric_reduced._reduce_states(metric.metric_state)

# Compute final metric
final_metric = metric_reduced.compute()

However, this approach relies on the private member function _reduce_states. It would be great if torchmetrics could offer an all_reduce(metrics: Iterable[MetricType]) -> MetricType function that achieves the same functionality.
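
For illustration, here is a minimal sketch of what such a helper could look like today; the name all_reduce is the one proposed above, but the function itself is hypothetical and still leans on the private _reduce_states hook internally:

from typing import Iterable

from torchmetrics import Metric


def all_reduce(metrics: Iterable[Metric]) -> Metric:
    """Hypothetical helper: merge the states of several metrics of the same type."""
    metrics = list(metrics)
    # Clone the first metric so the inputs are left untouched.
    reduced = metrics[0].clone()
    for metric in metrics[1:]:
        reduced._reduce_states(metric.metric_state)  # private API today
    return reduced

With a public version of this, the final metric would simply be all_reduce(metrics).compute().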

Alternatives

Additional context

github-actions[bot] commented 1 year ago

Hi! Thanks for your contribution, great first issue!

SkafteNicki commented 1 year ago

cc: @justusschock

justusschock commented 11 months ago

Hey @ytang137 and sorry for the late reply.

A metric takes the following input kwargs on initialization: process_group, dist_sync_fn, and distributed_available_fn.

These are used for all the syncs we run internally. Can you give an example where these functions wouldn't be sufficient? In general, a reduction is always applied after an all_gather, once all states have already been synced.
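
For reference, a rough sketch of how these kwargs are passed at construction time; the my_gather function below is hypothetical, and the exact signature torchmetrics expects for dist_sync_fn should be checked against the Metric documentation:

from typing import Any, List, Optional

import torch
from torchmetrics.classification import BinaryAccuracy


def my_gather(tensor: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
    # Hypothetical gather for a non-DDP backend: return the tensor as seen by
    # every worker. This single-worker stand-in just wraps the local tensor;
    # a real backend would delegate to its own collective.
    return [tensor]


metric = BinaryAccuracy(
    dist_sync_fn=my_gather,                 # custom all_gather implementation
    distributed_available_fn=lambda: True,  # report that a distributed env is available
    process_group=None,                     # may not map onto non-DDP backends
)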

ytang137 commented 11 months ago

Hi @justusschock , thanks for getting back to me. The use case I had in mind is to use torchmetrics outside the context of DDP, or even outside PyTorch altogether, where concepts such as process_group don't apply.

Consider this example: we have inference results from a large dataset saved in a database, and we want to compute metrics grouped by date and also report metrics aggregated over the entire dataset at the end. One approach would be to use Dask or Modin to compute the metrics in a distributed fashion: dataset.groupby("date").apply(my_metric_function). The my_metric_function here returns a torchmetrics metric object whose states have been updated by the data from a specific date, so the groupby operation returns a series of metric objects. It would be very nice to have an all_reduce(metrics: Iterable[Metric]) -> Metric function that reduces them to compute the metric over the entire dataset.

Is it currently possible to reduce metrics this way? Correct me if I'm wrong, but it seems that the concepts of process_group, dist_sync_fn, and distributed_available_fn don't apply in this example; or maybe these parameters can be used in a certain way to support this use case? Thanks.
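
To make the workflow concrete, here is a rough sketch of the pattern described above, using pandas for illustration (Dask and Modin expose the same groupby/apply interface); the column names are made up, and the all_reduce call at the end is the proposed, not-yet-existing API:

import pandas as pd
import torch
from torchmetrics.classification import BinaryFBetaScore

# Toy inference-results table standing in for the database export.
df = pd.DataFrame({
    "date": ["2023-01-01"] * 3 + ["2023-01-02"] * 3,
    "pred": [0, 1, 1, 0, 0, 1],
    "target": [0, 1, 0, 1, 0, 1],
})


def my_metric_function(group: pd.DataFrame) -> BinaryFBetaScore:
    # One metric object per date, updated only with that date's rows.
    metric = BinaryFBetaScore(1.0)
    metric.update(torch.tensor(group["pred"].values), torch.tensor(group["target"].values))
    return metric


# Per-date metrics, each holding the state for its own group.
per_date_metrics = df.groupby("date").apply(my_metric_function)

# Desired: reduce the collected metric objects into one dataset-wide metric.
# overall = all_reduce(per_date_metrics.tolist())   # proposed API
# overall.compute()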

ytang137 commented 1 month ago

Hi @justusschock @SkafteNicki @Borda @lantiga , I'm hoping to revive this thread. Below is a concrete example to help get the discussion going.

[Update 7/12]: I realized that what I'm asking for is essentially the merge_state(self: TSelf, metrics: Iterable[TSelf]) -> TSelf method on torcheval's Metric base class.
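
For reference, a rough sketch of how that torcheval API reads in practice, based only on the signature quoted above (BinaryAccuracy is just an arbitrary torcheval metric chosen for illustration; see torcheval's docs for exact usage):

import torch
from torcheval.metrics import BinaryAccuracy

# Each shard keeps its own metric object, exactly as in the ray example below.
m1, m2 = BinaryAccuracy(), BinaryAccuracy()
m1.update(torch.tensor([1.0, 0.0]), torch.tensor([1, 0]))
m2.update(torch.tensor([0.0, 1.0]), torch.tensor([1, 1]))

# merge_state folds the other metrics' states into m1, so compute() afterwards
# reflects both shards.
m1.merge_state([m2])
m1.compute()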

Example

We are looking to use torchmetrics in a data-parallel framework other than DDP, such as ray. Imagine that we are using ray to process shards of a dataset in parallel, and each ray task returns a metric object whose state has been updated by its corresponding shard.

Using BinaryFBetaScore as an example:

from torchmetrics.classification import BinaryFBetaScore
from torch import tensor

metric_shard_1 = BinaryFBetaScore(1.0, threshold=0.5, compute_with_cache=False)
metric_shard_2 = BinaryFBetaScore(1.0, threshold=0.5, compute_with_cache=False)
metric_shard_3 = BinaryFBetaScore(1.0, threshold=0.5, compute_with_cache=False)
metric_all = BinaryFBetaScore(1.0, threshold=0.5, compute_with_cache=False)
target_1 = tensor([0, 1, 0, 1, 0, 1])
preds_1 = tensor([0, 1, 1, 0, 0, 1])
target_2 = tensor([0, 1, 0, 1, 0, 1])
preds_2 = tensor([0, 0, 0, 0, 0, 1])
target_3 = tensor([1, 1, 0, 1, 0, 1])
preds_3 = tensor([0, 1, 1, 1, 1, 1])

# Each line below happens in a separate `ray` task
metric_shard_1.update(preds_1, target_1)
metric_shard_2.update(preds_2, target_2)
metric_shard_3.update(preds_3, target_3)

# When computation on all shards finishes, `ray` gathers a list of metric objects:
metric_list = [metric_shard_1, metric_shard_2, metric_shard_3]

Naturally we'd like to reduce the metrics in metric_list to compute a metric on the entire dataset.

My hacky approach

I found this hacky approach to work:

reduced_metric = metric_list[0].clone()
for metric in metric_list[1:]:
    reduced_metric._reduce_states(metric.metric_state)

# This gives the same result as a single metric updated on the entire dataset
metric_all.update(preds_1, target_1)
metric_all.update(preds_2, target_2)
metric_all.update(preds_3, target_3)
assert metric_all.compute().item() == reduced_metric.compute().item()

Obviously this is hacky because it relies on the private method _reduce_states, which is not part of the public API and could change without notice.

The ask

I sincerely believe that providing this reduction-outside-of-DDP functionality would make torchmetrics more widely applicable. Thank you very much.