Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0
2.15k stars 409 forks

BinaryAUROC hangs when calling metric.compute on multi-node multi-cards #2758

Closed gouchangjiang closed 2 months ago

gouchangjiang commented 2 months ago

🐛 Bug

torchmetrics.classification.BinaryAUROC.compute() hangs when running the program on multi-node multi-card.

To Reproduce

Launch the following code snippet by torchrun

import os

import torch
import torch.distributed as dist
from torchmetrics.classification import BinaryAUROC

def calculate_auc(metric, local_rank):
    preds = torch.randn([1024, 512]).cuda(local_rank)
    target = torch.randint(low=0, high=2, size=(1024, 512)).cuda(local_rank)
    current_auc = metric(preds=preds, target=target)
    return current_auc

if __name__ == '__main__':
    dist.init_process_group(backend="nccl")
    metric = BinaryAUROC(thresholds=600)
    local_rank = int(os.environ['LOCAL_RANK'])
    metric.to(f"cuda:{local_rank}")

    for iter in range(600):
        current_auc = calculate_auc(metric, local_rank)
        if local_rank == 0:
            print(f'=== step: {iter}, auc: {current_auc} ===')
            if iter == 100:
                # only rank 0 reaches compute(), so the collective sync inside it hangs
                print(f'=== average auc over the first 101 steps: {metric.compute()} ===')

Expected behavior

The program should print the AUC at every step. Instead, it prints the AUC until the 100th iteration, where it hangs inside metric.compute().

Environment

github-actions[bot] commented 2 months ago

Hi! Thanks for your contribution, great first issue!

Borda commented 2 months ago

@gouchangjiang was it resolved? 🤔

gouchangjiang commented 2 months ago
> metric.compute()

Yeah, I solved the issue just after posting it. The problem is that metric.compute() was only called by rank 0. All ranks must call it: compute() performs a collective synchronization to aggregate state across processes, so if only one rank enters the collective, the other ranks never issue the matching call and rank 0 blocks forever.