Context
As depicted above, we are extending the TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict).
We support TensorDict in both eager mode and distributed (sharded) mode: Input (Union[KJT, TD]) ==> EBC ==> Output (KT)
In eager mode, we directly call td_to_kjt in the forward function to convert TD to KJT.
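A minimal sketch of what the TD-to-KJT conversion could look like (illustrative only, not the actual td_to_kjt implementation; it assumes the TensorDict holds one jagged-layout torch.nested tensor of ids per sparse feature):

```python
import torch
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def td_to_kjt_sketch(td: TensorDict) -> KeyedJaggedTensor:
    # Assumption: each TD entry is a jagged-layout torch.nested tensor with
    # one row of ids per sample for that sparse feature.
    keys = list(td.keys())
    # Packed id values of all features, concatenated in key order.
    values = torch.cat([td[k].values() for k in keys])
    # Per-sample lengths recovered from the jagged offsets of each feature.
    lengths = torch.cat([td[k].offsets().diff() for k in keys])
    return KeyedJaggedTensor(keys=keys, values=values, lengths=lengths)
```

Under this sketch, the eager path amounts to ebc(td_to_kjt_sketch(td)); the module performs the equivalent conversion inside its forward.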
In distributed mode, we do the conversion inside the ShardedEmbeddingBagCollection, specifically in the input_dist, where the input sparse features are prepared (permuted) for the KJTAllToAll communication.
In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the KJTAllToAll communication.
In the TD scenario, by contrast, the input TD is converted directly into the permuted KJT, ready for the subsequent KJTAllToAll communication.
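A conceptual comparison of the two paths in input_dist (illustrative only, not the ShardedEmbeddingBagCollection code; it reuses the TD layout assumption and imports from the sketch above, and kjt, td, feature_names, and feature_order are hypothetical placeholders, with feature_order being the permutation expected by KJTAllToAll):

```python
# KJT input: permute (and partially duplicate, if needed) the already-built KJT.
permuted_kjt = kjt.permute(feature_order)

# TD input: build the KJT directly in the permuted key order, so no separate
# KJT.permute pass is needed before the KJTAllToAll communication.
keys = [feature_names[i] for i in feature_order]
values = torch.cat([td[k].values() for k in keys])
lengths = torch.cat([td[k].offsets().diff() for k in keys])
permuted_kjt = KeyedJaggedTensor(keys=keys, values=values, lengths=lengths)
```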
ref: D63436011
Details
td_to_kjt is implemented in Python, which introduces a CPU perf regression. However, it is not on the training critical path, so it has minimal impact on the overall training QPS (see the benchmark results in the test plan).
Currently only the EBC use case is supported.
WARNING: TensorDict supports neither weighted jagged tensors nor variable batch sizes.
NOTE: All the following comparisons are between the KJT.permute in the KJT input scenario and the TD-KJT conversion in the TD input scenario.
Both KJT.permute and the TD-KJT conversion are correctly marked in the TrainPipelineBase traces.
The TD-KJT conversion has more real execution on the CPU, but the heavy-lifting computation runs on the GPU, where it is delayed/blocked by the backward pass of the previous batch. The difference in GPU runtime is small (~10%).
{F1949366822}
For the Copy-Batch-To-GPU part, TD issues more fragmented HtoD copies, while KJT issues a single contiguous HtoD copy.
Runtime-wise they are similar (within ~10%).
{F1949374305}
In the most commonly used TrainPipelineSparseDist, where Copy-Batch-To-GPU and the CPU runtime are not on the critical path, we observe very similar training QPS in the pipeline benchmark (within ~1%).
{F1949390271}
With an increased data size, the GPU runtime is 4x.
{F1949386106}
Conclusion
[Enablement] With this approach (replacing the KJT permute with TD-KJT conversion), the EBC can now take TensorDict as the module input in both single-GPU and multi-GPU (sharded) scenarios, tested with TrainPipelineBase, TrainPipelineSparseDist, TrainPipelineSemiSync, and TrainPipelinePrefetch.
[Performance] The TD host-to-device data transfer might not necessarily be a concern/blocker for the most commonly used train pipeline (TrainPipelineSparseDist).
[Feature Support] To become production-ready, TensorDict needs to (1) integrate the KJT.weights data, and (2) support variable batch sizes, both of which are used in almost all production models.
[Improvement] There are two major operations we can improve: (1) moving the TensorDict from host to device, and (2) converting the TD to a KJT. Currently both are in a vanilla state. Since we are not sure what the real traces with production models would look like, we cannot yet tell whether these improvements are needed/helpful.
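For reference, a vanilla version of those two operations could look like the following (a sketch under the same assumptions as the earlier td_to_kjt_sketch, not the actual code):

```python
# (1) Move the TensorDict from host to device. TensorDict exposes a
#     tensor-like .to(); non_blocking is one possible knob to overlap copies.
td = td.to(torch.device("cuda"), non_blocking=True)

# (2) Convert the TD to a KJT on the device (see the td_to_kjt_sketch above).
kjt = td_to_kjt_sketch(td)
```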
Differential Revision: D65103519