pytorch / torchrec

PyTorch domain library for recommendation systems
https://pytorch.org/torchrec/
BSD 3-Clause "New" or "Revised" License

torchrec change for dynamic embedding #2533

Open kanghui0204 opened 3 weeks ago

kanghui0204 commented 3 weeks ago

Hi TorchRec experts,

We would like to incorporate NVIDIA HKV into the existing TorchRec workflow to extend TorchRec's capabilities for model-parallel dynamic embedding.

We aim to integrate HKV dynamic embedding into the TorchRec workflow as a new type of embedding table. To avoid disrupting the original TorchRec code, we have designed a mechanism for registering new embedding tables, which lets us and other users plug a customized embedding table into the TorchRec workflow. Our modifications mainly target the following two parts:

  1. Registering a new customized compute table during the creation of the embedding table and lookup, and accepting its customized parameters.

  2. Since the range of indices for dynamic embedding is unlimited, we need the input distribution to perform round-robin distribution. (Our current PR serves as a reference. For example, in the input dist section we have only modified the RW code, but all sharding types, such as TWRW, will need to be supported.)

Our code is based on v0.7, and it can be easily migrated to the latest code. We are initiating this PR as a reference for further discussions with you. We hope to support a high-performance dynamic embedding feature.

facebook-github-bot commented 3 weeks ago

Hi @kanghui0204!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

dstaay-fb commented 2 weeks ago

Thanks for the proposal.

RE [1]: It would help to put together a toy example of what you're doing, so we can see how you intend to use this API; ideally, you could create a few multi-GPU tests (with appropriate mocking if needed, etc.).

RE [2]: It's not well documented, but we can actually support round-robin-based RW sharding today; it's used in ZCH workflows (the bucketization strategy is % world_size). Basically, passing in RwSparseFeaturesDist(.., feature_hash_dim = [0,....0]) triggers this logic, which calls into the FBGEMM block_bucketization kernels. Coincidentally, logic was just added in this area; take a look at the tests in PR https://github.com/pytorch/torchrec/pull/2538, specifically the case where we set input_hash_size=0 on ZCH modules for the full behavior (albeit a different use case).
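For readers following along, the two bucketization strategies can be sketched in plain Python. This is an illustration of the behavior described above, not the FBGEMM kernel: `block_size_for` and `bucketize` are hypothetical helpers, and treating a zero block size as the switch to `% world_size` is an assumption taken from the description in this comment.

```python
from typing import List


def block_size_for(hash_size: int, world_size: int) -> int:
    """Ceil-divide the hash size across ranks; a hash size of 0 yields block size 0."""
    return (hash_size + world_size - 1) // world_size


def bucketize(ids: List[int], world_size: int, block_size: int) -> List[List[int]]:
    """Assign each id to a rank.

    block_size > 0:  contiguous blocks, rank = id // block_size
    block_size == 0: round-robin,       rank = id % world_size (assumed convention)
    """
    buckets: List[List[int]] = [[] for _ in range(world_size)]
    for idx in ids:
        if block_size > 0:
            rank = idx // block_size
            if rank >= world_size:
                # Contiguous blocks need a bounded id range, which is the
                # problem raised for dynamic embeddings.
                raise ValueError(f"id {idx} is outside the hashed range")
        else:
            rank = idx % world_size
        buckets[rank].append(idx)
    return buckets


world_size = 4

# Bounded ids: hash size 8 gives block size 2, so ids 0-1 go to rank 0, 2-3 to rank 1, ...
blocked = bucketize([0, 1, 2, 5, 7], world_size, block_size_for(8, world_size))
# blocked == [[0, 1], [2], [5], [7]]

# Unbounded ids: hash size 0 gives block size 0, so ids are spread by id % world_size.
round_robin = bucketize([0, 1, 2, 5, 7, 10**12 + 3], world_size, block_size_for(0, world_size))
# round_robin == [[0], [1, 5], [2], [7, 10**12 + 3]]
```

The second call shows why the modulo path suits dynamic embeddings: an arbitrarily large id still maps to a valid rank, whereas the block path rejects anything beyond `block_size * world_size`.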

kanghui0204 commented 2 weeks ago

Hi @dstaay-fb, thank you very much for the quick reply!

RE1: I will prepare an example for you as a reference as soon as possible.

RE2: Sorry, I couldn't find the test for input_hash_size=0 on ZCH modules in PR 2538.

Do you mean that setting the hash size of each table to 0 will make block_bucketize_sparse_features in FBGEMM switch from contiguous block partitioning to round-robin partitioning? It looks like we need to modify the sharding_infos input to BaseRwEmbeddingSharding (https://github.com/dstaay-fb/torchrec/blob/export-D62483238/torchrec/distributed/sharding/rw_sharding.py#L115), is that correct?