cellarium-ai / cellarium-ml

Distributed single-cell data analysis.
BSD 3-Clause "New" or "Revised" License

Fix `GatherLayer.backward` and make the test more stringent #184

Closed: ordabayevy closed this 4 months ago

ordabayevy commented 4 months ago

In `GatherLayer.backward`, each gradient in `*grads` must be summed across all ranks before this rank's slot is selected. Here is my initial fix, which worked for `ContrastiveMLP`:

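    import torch
    import torch.distributed as dist


    class GatherLayer(torch.autograd.Function):
        # forward() omitted; only the fixed backward() is shown
        @staticmethod
        def backward(ctx, *grads) -> torch.Tensor:
            # grads[k]: this rank's gradient w.r.t. the tensor gathered from rank k
            new_grads = []
            for grad in grads:
                grad = grad.contiguous()
                # Sum this slot's gradient over all ranks (in place)
                dist.all_reduce(grad, op=dist.ReduceOp.SUM)
                new_grads.append(grad)
            # Keep only the reduced gradient for this rank's own input
            grad_out = new_grads[dist.get_rank()]
            return grad_out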
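To make the reduction explicit, here is a single-process simulation in plain Python that does not touch `torch.distributed`. It is a sketch under simplifying assumptions: floats stand in for gradient tensors, `all_grads[r][k]` (a hypothetical name) is the gradient that rank `r` computes for the tensor gathered from rank `k`, and `simulated_all_reduce_sum` is a hypothetical stand-in for `dist.all_reduce(..., op=dist.ReduceOp.SUM)`. Under this model, each rank's `grad_out` is the sum, over all ranks, of the gradient for that rank's own slot.

```python
from typing import List

# Hypothetical layout: all_grads[r][k] is the gradient rank r computes w.r.t.
# the k-th gathered tensor (the one contributed by rank k).
# Floats stand in for tensors to keep the sketch dependency-free.
Grads = List[List[float]]


def simulated_all_reduce_sum(values_per_rank: List[float]) -> float:
    """Hypothetical stand-in for dist.all_reduce(..., op=dist.ReduceOp.SUM) on one slot."""
    return sum(values_per_rank)


def fixed_backward(all_grads: Grads) -> List[float]:
    """Return grad_out for every rank, following the order used in the fix above."""
    world_size = len(all_grads)
    # Reduce each gathered slot k across all ranks.
    reduced = [
        simulated_all_reduce_sum([all_grads[r][k] for r in range(world_size)])
        for k in range(world_size)
    ]
    # Each rank keeps only the slot that corresponds to its own input.
    return [reduced[rank] for rank in range(world_size)]


# Two ranks whose gradients differ per slot.
all_grads = [[1.0, 2.0], [3.0, 4.0]]
print(fixed_backward(all_grads))  # [4.0, 6.0]
```

Rank 0 receives 1.0 + 3.0 and rank 1 receives 2.0 + 4.0, so every rank's contribution to each slot is accounted for.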

Then we updated it to the current version (which selects `grads[rank]` first and then performs the reduction, so it returns an averaged `grad_out` on all ranks), but we didn't test it with the `ContrastiveMLP` model. The test in `test_gather.py` passed only because it was symmetric across ranks and therefore not stringent enough. I have updated the test to be asymmetric across ranks, which catches the bug in the current implementation. See the discussion at https://github.com/lightly-ai/lightly/pull/1531, from which I also borrowed the stacking implementation (stacking all gradients into one tensor so that a single `all_reduce` replaces the for loop).
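Why a symmetric test cannot catch the ordering bug can be sketched with a plain-Python model. The functions below are hypothetical: `reduce_then_select` follows the order of the fix, while `select_then_reduce` models the reversed order. To isolate the ordering alone, both use a sum; since the current version is described as averaging, this is a simplification and not the actual cellarium-ml code or test.

```python
from typing import List

# Hypothetical layout: all_grads[r][k] is rank r's gradient for gathered slot k.
Grads = List[List[float]]


def reduce_then_select(all_grads: Grads) -> List[float]:
    """Order used by the fix: reduce every slot across ranks, then select own slot."""
    world_size = len(all_grads)
    return [sum(all_grads[r][rank] for r in range(world_size)) for rank in range(world_size)]


def select_then_reduce(all_grads: Grads) -> List[float]:
    """Hypothetical model of the reversed order: each rank selects its own slot
    first, then the selections are summed across ranks, so every rank ends up
    with the same value."""
    world_size = len(all_grads)
    total = sum(all_grads[rank][rank] for rank in range(world_size))
    return [total] * world_size


# Symmetric input: every rank sees the same gradient for every slot.
symmetric = [[1.0, 1.0], [1.0, 1.0]]
# Asymmetric input: gradients depend on both the rank and the slot.
asymmetric = [[1.0, 2.0], [3.0, 4.0]]

print(reduce_then_select(symmetric), select_then_reduce(symmetric))
# [2.0, 2.0] [2.0, 2.0]  -> the symmetric input cannot tell the orders apart
print(reduce_then_select(asymmetric), select_then_reduce(asymmetric))
# [4.0, 6.0] [5.0, 5.0]  -> the asymmetric input exposes the difference
```

Making the per-rank inputs differ is what forces the two orders to produce different gradients, which is the property the updated test relies on.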