xiexbing opened this issue (https://github.com/pytorch/torchrec/issues/2257) 3 months ago. Status: Open.
In the forward pass, with table-wise sharding, when is pooling executed?
The pooling is done in FBGEMM_GPU, not in torchrec itself. For example: https://pytorch.org/FBGEMM/fbgemm_gpu-python-api/table_batched_embedding_ops.html
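As a mental model of the pooling that happens inside FBGEMM_GPU, here is a minimal pure-Python sketch of sum pooling for a single table. Bags are described by CSR-style offsets, so bag b uses indices[offsets[b]:offsets[b+1]]. This is not FBGEMM's API or kernel: `pooled_lookup`, the choice of sum pooling, and this offsets layout are assumptions for illustration, and the real op is table-batched, handling many tables in one CUDA launch.

```python
from typing import List, Sequence


def pooled_lookup(
    weights: Sequence[Sequence[float]],  # embedding table: one row per index
    indices: Sequence[int],              # flattened lookup indices for all bags
    offsets: Sequence[int],              # bag b covers indices[offsets[b]:offsets[b + 1]]
) -> List[List[float]]:
    """Sum-pool the embedding rows of each bag (hypothetical helper)."""
    dim = len(weights[0])
    pooled = []
    for b in range(len(offsets) - 1):
        acc = [0.0] * dim  # an empty bag pools to a zero vector
        for idx in indices[offsets[b]:offsets[b + 1]]:
            row = weights[idx]
            for d in range(dim):
                acc[d] += row[d]
        pooled.append(acc)
    return pooled
```

For example, with rows [[1, 0], [0, 1], [2, 2]], indices [0, 2, 1] and offsets [0, 2, 3], bag 0 pools rows 0 and 2 into [3.0, 2.0] and bag 1 is just row 1.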
The backward-pass key sorting and gradient reduction are also done inside fbgemm_gpu.
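Why sorting and reduction are needed: with sum pooling, every row read by a bag receives that bag's output gradient, so a row read by several lookups must get the sum of those gradients before one update is applied to it. The sketch below shows the idea in pure Python, assuming sum pooling: sort the (row index, bag) pairs by row index, then reduce each run of equal indices. fbgemm_gpu does the equivalent work in CUDA; `aggregate_row_grads` is a hypothetical name, not part of its API.

```python
from typing import List, Sequence, Tuple


def aggregate_row_grads(
    indices: Sequence[int],
    offsets: Sequence[int],
    grad_pooled: Sequence[Sequence[float]],  # dL/d(pooled output), one vector per bag
) -> List[Tuple[int, List[float]]]:
    """Sum-pooling backward: (row index, summed gradient), sorted by row index."""
    # Each lookup contributes its bag's output gradient to the row it read.
    pairs = [
        (idx, b)
        for b in range(len(offsets) - 1)
        for idx in indices[offsets[b]:offsets[b + 1]]
    ]
    # Key sorting: bring duplicate row indices next to each other.
    pairs.sort(key=lambda p: p[0])
    # Gradient reduction: sum each run of equal indices (a segment reduction).
    dim = len(grad_pooled[0]) if grad_pooled else 0
    result: List[Tuple[int, List[float]]] = []
    for idx, b in pairs:
        if not result or result[-1][0] != idx:
            result.append((idx, [0.0] * dim))
        acc = result[-1][1]
        for d in range(dim):
            acc[d] += grad_pooled[b][d]
    return result
```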
Can you point out the classes in fbgemm_gpu that cover the backward sorting and gradient aggregation?
It's written in C++, specifically CUDA source code.
(Some of the files are templatized and not directly readable; you can build fbgemm_gpu on your own and check the generated files.) For example, this template file defines the autograd function that wraps the forward and backward entry points, and this is a generated file produced from that template.
Sorry, a follow-up question about forward-pass communication and pooling. I profiled the forward pass with both Nsight and the PyTorch profiler. I see the all2all calls at the Python level and the low-level NCCL SendReceive calls that implement all2all, but I didn't see any calls that map to pooling. Does the pooling actually happen within SendReceive? Or after SendReceive but within all2all? Or after all2all, and it just isn't profiled properly?
Another question: in the forward pass, before all2all, how are duplicated embedding indices removed per table, so that all2all runs with the unique indices as input and the corresponding embedding vectors as output? Please point out the code. Thanks.
Does the pooling actually happen within SendReceive, after SendReceive but within all2all, or after all2all (and it just isn't profiled properly)?
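The pooling is done before the second all2all. Maybe this figure is clearer: image.png https://github.com/user-attachments/assets/ed784d80-0b2f-4067-b68a-fa62a29cd891
The first all2all sends the lookup keys, which are fed into fbgemm, and fbgemm does the pooling for you. The second all2all exchanges the pooled embeddings. So the pooling actually happens in CUDA kernels. image.png https://github.com/user-attachments/assets/80c28ad6-7199-4cf6-aa27-bf30734b7cb9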
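To make the two-collective flow concrete, here is a single-process simulation of table-wise (TW) sharding. It assumes sum pooling and that each table lives entirely on one rank; lists stand in for tensors and plain loops stand in for the two all2alls. `simulate_tw_forward`, `sum_pool`, and the argument layout are hypothetical and are not TorchRec's API.

```python
from typing import Dict, List, Tuple

Bag = List[int]
Vec = List[float]


def sum_pool(rows: List[Vec], bag: Bag) -> Vec:
    """Sum the embedding rows named by one bag (empty bag -> zero vector)."""
    acc = [0.0] * len(rows[0])
    for i in bag:
        for d, x in enumerate(rows[i]):
            acc[d] += x
    return acc


def simulate_tw_forward(
    local_inputs: List[Dict[str, List[Bag]]],  # per rank: table name -> bags of its local batch
    tables: Dict[str, List[Vec]],              # table name -> embedding rows
    owner: Dict[str, int],                     # table name -> rank holding the whole table
) -> List[Dict[str, List[Vec]]]:
    world = len(local_inputs)
    # First all2all: each rank sends every table's keys to that table's owner.
    inbox: List[List[Tuple[int, str, List[Bag]]]] = [[] for _ in range(world)]
    for src, feats in enumerate(local_inputs):
        for name, bags in feats.items():
            inbox[owner[name]].append((src, name, bags))
    # Lookup + pooling on the owner (the step FBGEMM runs as CUDA kernels).
    outbox: List[List[Tuple[str, List[Vec]]]] = [[] for _ in range(world)]
    for msgs in inbox:
        for src, name, bags in msgs:
            pooled = [sum_pool(tables[name], bag) for bag in bags]
            outbox[src].append((name, pooled))
    # Second all2all: pooled vectors return to the rank that owns the batch.
    return [dict(msgs) for msgs in outbox]
```

Note that pooling sits strictly between the two exchanges, which matches the answer above that it does not happen inside SendReceive.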
Thanks for the explanation; it helps a lot. Let me dig deeper into some details. 1. Can I assume the first all2all actually communicates the KJT (KeyedJaggedTensor) of each batch? So after this all2all, each GPU holding a shard of an embedding table launches an fbgemm kernel to do the lookup and pooling, and then the second all2all sends the pooled embedding vectors back to the owner of the KJT batch. If this assumption is correct: for row-wise parallelism, all2all is still used in the forward pass, so if a sample has indices across multiple shards (meaning its embedding vectors are stored on different GPUs), how is pooling done?
One correction: for RW (row-wise) sharding there is no second all2all; instead, it is a ReduceScatter. (Table-wise sharding does use an all2all there.)
1. Can I assume the first all2all actually communicates the KJT of each batch?
AFAIK, yes. But note that a KJT all2all is usually composed of several tensor all2alls.
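A rough single-process sketch of why one KJT all2all turns into several tensor all2alls: the jagged structure travels as separate tensors (here just bag lengths and flattened values), and a receiver needs the lengths to know how to split the values. The two-step order, the `Chunks` layout, and both helper names are assumptions for illustration; TorchRec's KJT all2all may exchange additional tensors, such as weights or per-key lengths.

```python
from typing import List, Tuple

Chunks = List[List[List[int]]]  # chunks[src][dst] -> list of ints sent from src to dst


def all_to_all(send: Chunks) -> Chunks:
    """Single-process stand-in for a tensor all2all: returns recv[dst][src]."""
    world = len(send)
    return [[send[src][dst] for src in range(world)] for dst in range(world)]


def kjt_all_to_all(lengths: Chunks, values: Chunks) -> Tuple[Chunks, Chunks]:
    """Send a jagged (lengths, values) pair as two tensor all2alls."""
    # All2all #1: bag lengths, so each receiver learns how many values to expect.
    recv_lengths = all_to_all(lengths)
    expected = [[sum(chunk) for chunk in row] for row in recv_lengths]
    # All2all #2: the flattened values, whose split sizes follow from the lengths.
    recv_values = all_to_all(values)
    for row_vals, row_expected in zip(recv_values, expected):
        for chunk, n in zip(row_vals, row_expected):
            if len(chunk) != n:
                raise ValueError("values do not match the exchanged lengths")
    return recv_lengths, recv_values
```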
If a sample has indices across multiple shards (meaning its embedding vectors are stored on different GPUs), how is pooling done?
There needs to be a collective. For RW sharding, that is a reduce-scatter.
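Why a reduce-scatter is enough for RW: assuming sum pooling and rows split into contiguous blocks, each rank can pool just the rows it owns into a dense partial result for every bag in the global batch, and summing those partials reproduces the full pooled value. The single-process sketch below simulates that. The contiguous block layout, the evenly divisible global batch, and `rw_forward` are illustrative assumptions, not TorchRec's sharding code.

```python
from typing import List

Vec = List[float]


def rw_forward(bags: List[List[int]], table: List[Vec], world: int) -> List[List[Vec]]:
    """Simulate RW sharding: partial pooling on every rank, then a reduce-scatter."""
    if len(bags) % world != 0:
        raise ValueError("global batch must split evenly across ranks")
    rows_per_rank = -(-len(table) // world)  # ceil division; rank r owns a contiguous row block
    dim = len(table[0])

    # Each rank pools only the rows it owns, but emits a dense partial for ALL bags.
    partials: List[List[Vec]] = []
    for r in range(world):
        lo, hi = r * rows_per_rank, (r + 1) * rows_per_rank
        part = []
        for bag in bags:
            acc = [0.0] * dim
            for i in bag:
                if lo <= i < hi:
                    for d in range(dim):
                        acc[d] += table[i][d]
            part.append(acc)
        partials.append(part)

    # Reduce-scatter: sum the partials elementwise; rank r keeps its slice of bags.
    per_rank = len(bags) // world
    result = []
    for r in range(world):
        mine = range(r * per_rank, (r + 1) * per_rank)
        result.append(
            [[sum(partials[p][b][d] for p in range(world)) for d in range(dim)] for b in mine]
        )
    return result
```

Bag [0, 3] below straddles both shards: rank 0 contributes row 0 and rank 1 contributes row 3, and the reduce-scatter adds them.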
Why not have only one all2all, so that each KJT batch owner receives the individual embedding vectors (based on its keys) and launches a CUDA kernel to pool them locally?
Initially, each rank has its own data-parallel (DP) input. The first all2all sends the keys to the rank holding the corresponding shard for lookup; I don't think we can skip it. As I clarified, the second collective is a reduce-scatter (RS) for RW. The local lookup result from fbgemm is a dense tensor covering all bags, so it is ready for the RS. If fbgemm did not pool and instead returned a jagged tensor, we would need a jagged-tensor all2all plus a local reduce.
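One way to see the cost of the alternative raised in the question (shipping unpooled vectors and pooling at the batch owner): the jagged payload scales with the total number of looked-up indices, while the pooled payload scales only with the number of bags. A back-of-the-envelope sketch, assuming fp32 embeddings and int32 lengths; both helper names are hypothetical.

```python
from typing import Sequence


def pooled_bytes(lengths: Sequence[int], dim: int, elem_bytes: int = 4) -> int:
    """Pooled output: one dim-sized vector per bag, regardless of bag length."""
    return len(lengths) * dim * elem_bytes


def unpooled_bytes(
    lengths: Sequence[int], dim: int, elem_bytes: int = 4, length_bytes: int = 4
) -> int:
    """Jagged output: one vector per looked-up index, plus per-bag lengths for splitting."""
    return sum(lengths) * dim * elem_bytes + len(lengths) * length_bytes
```

With 1024 bags of 20 indices each and dim 128, the pooled payload is 524,288 bytes and the jagged one is 10,489,856 bytes, roughly 20x larger, before the receiver even runs its local reduce.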
Before all2all, how are duplicated embedding indices removed per table?
For EmbeddingBag, I think there's no deduplication. But for Embedding, there is deduplication in input_dist. You need to explicitly set use_index_dedup=True for EmbeddingCollectionSharder to enable the unique step.
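The deduplication idea can be sketched in pure Python: look up each distinct index once, then use an inverse mapping to restore the original positions, mirroring the semantics of `torch.unique(..., return_inverse=True)`. This illustrates the technique under our own assumptions (one table, plain lists, hypothetical helper names); it is not the input_dist code that `use_index_dedup=True` enables.

```python
from typing import List, Tuple


def dedup_indices(indices: List[int]) -> Tuple[List[int], List[int]]:
    """Return (unique_ids, inverse) such that indices[i] == unique_ids[inverse[i]]."""
    unique_ids = sorted(set(indices))
    position = {idx: pos for pos, idx in enumerate(unique_ids)}
    inverse = [position[idx] for idx in indices]
    return unique_ids, inverse


def expand(unique_rows: List[List[float]], inverse: List[int]) -> List[List[float]]:
    """Map rows looked up for the unique ids back to the original (duplicated) positions."""
    return [unique_rows[j] for j in inverse]
```

Only the unique ids would need to cross the all2all; for [7, 3, 7, 7, 1] that is three ids instead of five.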
Thanks for the clarification! Very helpful.
Sorry, I forgot to ask about some details of the collective communication operators. Just for confirmation: in the forward pass, the communication of embedding indices uses multiple all2all calls, one per component of the jagged tensor (e.g., one call for values, one for lengths, one for lengths per key). In both the forward and backward passes, the communication of embedding vectors uses only one call covering all batches across all keys and tables. If the second statement is correct, then as the system scales up, the data size of that call grows as n^2 (where n is the number of ranks in the system). Why not split the single call into multiple calls, each communicating about 64 MB (for best network bandwidth efficiency)?
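A sketch of the chunking idea proposed here, only to make the proposal concrete: split one all2all into rounds whose total payload stays under a budget (64 MB from the question, computed here as 64 * 2**20 bytes). It assumes the per-peer byte counts are known up front; `plan_chunked_all2all` is hypothetical and says nothing about whether TorchRec or NCCL already chunk large transfers internally.

```python
from typing import List


def plan_chunked_all2all(send_bytes: List[int], chunk_bytes: int = 64 * 2**20) -> List[List[int]]:
    """Split one all2all (send_bytes[p] = bytes for peer p) into rounds of <= chunk_bytes total."""
    if not send_bytes:
        return []
    per_peer = chunk_bytes // len(send_bytes)  # per-peer budget in each round
    if per_peer <= 0:
        raise ValueError("chunk_bytes must be at least the number of peers")
    remaining = list(send_bytes)
    rounds: List[List[int]] = []
    while any(remaining):
        this_round = [min(r, per_peer) for r in remaining]
        rounds.append(this_round)
        remaining = [r - s for r, s in zip(remaining, this_round)]
    return rounds
```

In a real multi-rank job, every rank would have to run the same number of rounds (for example, the maximum across ranks), because each round is itself a collective call.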
In the forward pass, with table-wise sharding, when is pooling executed? Is it after the all-to-all communication, and is it executed locally on the trainer? Where can I see the exact code in the TorchRec code base?
In the backward pass, with table-wise sharding, when are sorting and aggregation executed? Can you please point to the lines in the code base?