Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

Autograd with DDP #2745

Closed cw-tan closed 1 week ago

cw-tan commented 1 month ago

I have a setup with PyTorch Lightning where I'm using custom torchmetrics.Metric subclasses as loss function contributions. Now I want to be able to do this with DDP by setting dist_sync_on_step=True, but the gradients are not propagated during the all_gather. All I want is for the tensor on the current process to keep its autograd graph for the backward pass after the syncing operations. I've only just begun looking into distributed stuff in torch, so I'm not experienced in these matters. But following the forward() call of Metric (at each training batch step): it calls _forward_reduce_state_update(), which calls the compute() function wrapped by _wrap_compute(), which calls sync(), which finally calls _sync_dist(). And it looks like the syncing uses torchmetrics.utilities.distributed.gather_all_tensors.
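
For concreteness, here is a minimal sketch of the kind of metric-as-loss setup I described above (the metric name and the training_step usage are illustrative only, not my actual code):

import torch
from torchmetrics import Metric


class SquaredErrorLoss(Metric):  # illustrative metric used as a loss contribution
    def __init__(self, **kwargs):
        super().__init__(dist_sync_on_step=True, **kwargs)
        self.add_state("sum_sq", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self.sum_sq += torch.sum((preds - target) ** 2)
        self.n += target.numel()

    def compute(self) -> torch.Tensor:
        return self.sum_sq / self.n


# inside LightningModule.training_step, roughly:
#   loss = self.loss_metric(preds, target)  # forward() -> sync (dist_sync_on_step) -> compute()
#   return loss                             # backward needs the graph to survive the sync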

I just wanted to ask whether it is possible to achieve this by modifying _simple_gather_all_tensors (here)? _simple_gather_all_tensors is reproduced below for reference.

def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
    # all_gather fills gathered_result in place; the gathered copies do not
    # carry the autograd graph of the local `result`
    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result, group)
    return gathered_result

I'm guessing that result still carries the autograd graph. My naive hope is that we can just replace the current rank's entry of gathered_result with the input result (which carries the autograd graph) to achieve the desired effect.

For context, my use case is such that batches can have very inhomogeneous numels, so each device could have error tensors with very different numels, and taking a mean of per-device MeanSquaredErrors may not be ideal. Ideally, if the syncing kept the autograd graph, the per-step loss would be the "true" metric as per its definition and the gradients would be consistent with that definition (so syncing is done once for each loss metric contribution, and once for the backward at each training step, I think).
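
To make the numel issue concrete with made-up numbers: averaging per-rank MSEs weights every rank equally, whereas the "true" MSE weights every element equally.

import torch

err_rank0 = torch.tensor([1.0, 1.0])        # 2 residuals on rank 0
err_rank1 = torch.tensor([3.0] * 10)        # 10 residuals on rank 1

mse_rank0 = (err_rank0 ** 2).mean()         # 1.0
mse_rank1 = (err_rank1 ** 2).mean()         # 9.0

mean_of_mses = (mse_rank0 + mse_rank1) / 2                   # 5.0
true_mse = torch.cat([err_rank0, err_rank1]).pow(2).mean()   # 92 / 12 ≈ 7.67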

Thank you!

github-actions[bot] commented 1 month ago

Hi! Thanks for your contribution, great first issue!

cw-tan commented 1 month ago

It looks like adding gathered_result[torch.distributed.get_rank(group)] = result has worked for me so far, i.e.

def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result, group)
    # re-insert the local `result` so this rank's entry keeps the autograd graph
    gathered_result[torch.distributed.get_rank(group)] = result
    return gathered_result

found here.
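
As a rough check of the patched function (this sketch assumes an initialized process group, e.g. launched with torchrun and a gloo backend), only the local rank's entry should still require grad, and backward should flow through it:

import torch
import torch.distributed as dist

dist.init_process_group("gloo")
rank = dist.get_rank()

x = torch.randn(3, requires_grad=True)
result = (x ** 2).sum()

gathered = _simple_gather_all_tensors(result, dist.group.WORLD, dist.get_world_size())
print(rank, [t.requires_grad for t in gathered])  # True only at index `rank`

torch.stack(gathered).sum().backward()  # gradients reach x through the local entry
print(rank, x.grad)

dist.destroy_process_group()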