pytorch / torchdistx

Torch Distributed Experimental
BSD 3-Clause "New" or "Revised" License

Patch for GossipGraD algorithm #56

Closed aovladi closed 2 years ago

aovladi commented 2 years ago

What does this PR do? Please describe: Currently, the GossipGraD algorithm increments state.iteration every time comm_hook is called and later changes the topology based on state.iteration. This is incorrect, because comm_hook can be called multiple times during a single backward pass, so state.iteration advances faster than the number of training iterations. This patch addresses the issue.

GossipGraD now requires a num_modules parameter, which is used to determine when to switch topology.
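
A minimal sketch of how such counting might work, in plain Python. The class `HypotheticalGossipState`, its fields, and the helper `topology_index` are illustrative names and are not taken from the torchdistx source. The sketch also assumes that comm_hook fires exactly num_modules times per backward pass, which may not match the actual implementation.

```python
from dataclasses import dataclass


@dataclass
class HypotheticalGossipState:
    """Illustrative stand-in for the hook state; not the torchdistx class."""

    num_modules: int      # assumed: number of comm_hook calls per backward pass
    hook_calls: int = 0   # raw count of comm_hook invocations
    iteration: int = 0    # number of completed backward passes

    def __post_init__(self) -> None:
        if self.num_modules <= 0:
            raise ValueError("num_modules must be a positive integer")

    def on_comm_hook(self) -> bool:
        """Record one hook call; return True when a backward pass has just completed."""
        self.hook_calls += 1
        if self.hook_calls % self.num_modules == 0:
            # Only advance the iteration once all modules have reported.
            self.iteration += 1
            return True
        return False


def topology_index(state: HypotheticalGossipState, num_topologies: int) -> int:
    """Pick a topology from the completed-iteration count, not the raw hook count."""
    return state.iteration % num_topologies
```

With num_modules set to 3, six hook calls advance iteration twice rather than six times, so a topology choice derived from iteration changes once per backward pass instead of once per hook call.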

Unit tests covering the change are added. New experimental results show a general improvement in performance.

Does your PR introduce any breaking changes? If yes, please list them: List of all backwards-incompatible API changes.

Check list: