pytorch / torchrec

PyTorch domain library for recommendation systems
https://pytorch.org/torchrec/
BSD 3-Clause "New" or "Revised" License

Optimize performance of embeddings sharding #2568

Closed: iamzainhuda closed this 3 days ago

iamzainhuda commented 3 days ago

Summary: While working on TTFB, we observed that sharding embedding bags takes significant time and is one of the biggest contributors to TTFB, especially on large jobs. Strobelight profiling data showed that most of this time is spent in all_gather collective calls. Currently, sharded tensors are constructed one by one, and each construction calls a collective to exchange metadata, which is inefficient. A more efficient approach is to have every rank build its portion of the metadata for all tensors and exchange it in a single collective call, which significantly reduces overhead and improves performance.
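The difference between the two strategies can be illustrated with a small simulation. This is a hypothetical sketch, not the code from this PR: `LocalShardInfo`, `FakeProcessGroup`, `local_shard`, `build_per_table` and `build_batched` are invented names, row-wise sharding is assumed, and the collective is simulated in-process so the example runs without a process group. In a real job the exchange would be a `torch.distributed` collective such as `all_gather_object`. The point is only the call count: one collective per table versus one collective for all tables.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LocalShardInfo:
    # Hypothetical, simplified metadata for one rank's row-wise shard of a table.
    rank: int
    offsets: tuple  # (row_offset, col_offset)
    sizes: tuple    # (num_rows, dim)


class FakeProcessGroup:
    """Simulates all_gather across `world_size` ranks and counts collective calls."""

    def __init__(self, world_size):
        self.world_size = world_size
        self.collective_calls = 0

    def all_gather(self, per_rank_payloads):
        # In a real job each rank contributes one payload; here the caller passes
        # all of them at once and every rank would receive the full list.
        assert len(per_rank_payloads) == self.world_size
        self.collective_calls += 1
        return list(per_rank_payloads)


def local_shard(num_rows, dim, rank, world_size):
    # Assumed row-wise split: contiguous chunks of ceil(num_rows / world_size) rows.
    chunk = -(-num_rows // world_size)
    start = min(rank * chunk, num_rows)
    end = min(start + chunk, num_rows)
    return LocalShardInfo(rank, (start, 0), (end - start, dim))


def build_per_table(pg, tables):
    # Old pattern: one metadata collective per sharded table.
    result = {}
    for name, (rows, dim) in tables.items():
        payloads = [local_shard(rows, dim, r, pg.world_size) for r in range(pg.world_size)]
        result[name] = pg.all_gather(payloads)
    return result


def build_batched(pg, tables):
    # Batched pattern: each rank packs metadata for ALL tables, then one collective.
    payloads = [
        {name: local_shard(rows, dim, r, pg.world_size) for name, (rows, dim) in tables.items()}
        for r in range(pg.world_size)
    ]
    gathered = pg.all_gather(payloads)
    # Unpack into the same per-table layout the per-table path produces.
    return {name: [gathered[r][name] for r in range(pg.world_size)] for name in tables}


if __name__ == "__main__":
    tables = {"t0": (100, 16), "t1": (40, 8), "t2": (7, 4)}
    pg_a, pg_b = FakeProcessGroup(4), FakeProcessGroup(4)
    per_table = build_per_table(pg_a, tables)
    batched = build_batched(pg_b, tables)
    print(pg_a.collective_calls, pg_b.collective_calls)  # 3 1
    assert per_table == batched  # same metadata, fewer collectives
```

Both paths produce identical metadata; the batched path replaces N collectives with one, which matters when each collective carries fixed latency that grows with the number of ranks.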

Testing on 256 ranks showed a ~15x speedup.

Differential Revision: D65489998

facebook-github-bot commented 3 days ago

This pull request was exported from Phabricator. Differential Revision: D65489998