Closed: cw-tan closed this issue 1 week ago
Hi! Thanks for your contribution, great first issue!
It looks like adding `gathered_result[torch.distributed.get_rank(group)] = result` has worked for me so far, i.e.
```python
from typing import Any, List
import torch
from torch import Tensor

def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result, group)
    # re-insert the local tensor so the current rank's autograd graph is preserved
    gathered_result[torch.distributed.get_rank(group)] = result
    return gathered_result
```
found here.
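If you would rather not patch torchmetrics itself, the same idea can be tried from user code. Below is a minimal sketch that passes a gradient-preserving gather to a metric through the `dist_sync_fn` argument; it assumes your torchmetrics version still accepts that kwarg and calls it like the default `gather_all_tensors(result, group=...)`, and it only handles the equal-shape case:

```python
from typing import Any, List, Optional

import torch
from torch import Tensor
from torchmetrics import MeanSquaredError


def grad_preserving_gather(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
    """All-gather `result` across ranks, keeping the local rank's autograd graph."""
    world_size = torch.distributed.get_world_size(group)
    gathered = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, result, group=group)
    # all_gather returns plain copies, so put the local tensor (with its graph) back in
    gathered[torch.distributed.get_rank(group)] = result
    return gathered


# Hypothetical usage: pass the custom gather to any Metric subclass via `dist_sync_fn`.
metric = MeanSquaredError(dist_sync_on_step=True, dist_sync_fn=grad_preserving_gather)
```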
I have a setup with PyTorch Lightning where I'm using custom `torchmetrics.Metric`s as loss function contributions. Now I want to be able to do it with `ddp` by setting `dist_sync_on_step=True`, but the gradients are not propagated during the `all_gather`. All I want is for the tensor on the current process to have its autograd graph kept for the backward pass after the syncing operations. I've only just begun looking into distributed stuff in `torch`, so I'm not experienced in these matters. But following the `forward()` call of `Metric` (at each training batch step), it then calls `_forward_reduce_state_update()`, which calls the `compute()` function wrapped by `_wrap_compute()`, which would do `sync()`, which finally calls `_sync_dist()`. And it looks like the syncing uses `torchmetrics.utilities.distributed.gather_all_tensors`.
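For concreteness, here is a stripped-down sketch of the kind of setup I mean; the metric class and its state names are made up purely for illustration:

```python
import torch
from torch import Tensor
from torchmetrics import Metric


class SquaredErrorLoss(Metric):
    """Hypothetical custom metric used as a per-step loss contribution."""

    def __init__(self, **kwargs):
        # dist_sync_on_step=True is what triggers the sync path described above on every forward()
        super().__init__(dist_sync_on_step=True, **kwargs)
        self.add_state("sum_sq_err", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        self.sum_sq_err += torch.sum((preds - target) ** 2)
        self.total += target.numel()

    def compute(self) -> Tensor:
        return self.sum_sq_err / self.total


# Inside a LightningModule.training_step, roughly:
#     loss = self.se_loss(preds, target)  # forward() -> ... -> sync() -> _sync_dist()
#     return loss
```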
I just wanted to ask if it is possible to achieve what I want by modifying `_simple_gather_all_tensors` (here)? `_simple_gather_all_tensors` is presented here for reference. I'm guessing that `result` still carries the autograd graph. My naive hope is that we can just update `gathered_result` with the input `result` (carrying the autograd graph) to achieve the desired effect.

For context, my use case is such that batches can have very inhomogeneous `numel`s, so each device could have error tensors with very different `numel`s, such that taking a mean of `MeanSquaredError`s may not be ideal (a small numeric sketch of this mismatch is included at the end of this post). Ideally, if the syncing holds the autograd graph, the per-step loss would be the "true" metric as per its definition, and the gradients would be consistent with that definition (so syncing is done once for each loss metric contribution, and once for the backward at each training step, I think).

Thank you!
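Here is that small numeric sketch of the `numel` mismatch, with made-up values for two ranks:

```python
import torch

# Squared errors on two hypothetical ranks with different numbers of elements.
sq_err_rank0 = torch.tensor([1.0, 1.0])   # 2 elements, per-rank MSE = 1.0
sq_err_rank1 = torch.full((8,), 2.0)      # 8 elements, per-rank MSE = 2.0

mean_of_mses = (sq_err_rank0.mean() + sq_err_rank1.mean()) / 2  # (1.0 + 2.0) / 2 = 1.5
true_mse = torch.cat([sq_err_rank0, sq_err_rank1]).mean()       # 18.0 / 10 = 1.8

print(mean_of_mses.item(), true_mse.item())  # averaging per-rank MSEs understates the true MSE here
```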