vadimkantorov opened this issue 3 weeks ago
In my opinion, TRT has no such op (NJT); you can write a custom one.
If you want to use CUDA graphs, you need to set the max sequence length (use a fixed address for the sequence buffer to build the static graph), and set the min, opt, and max shapes for this input.
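For reference, a minimal sketch of the min/opt/max setup via a TensorRT optimization profile in the Python API; the input name `input_ids`, the ONNX file name, and the shapes below are just placeholders, not taken from any particular model:

```python
import tensorrt as trt

# Sketch only: parse an ONNX model and attach a dynamic-shape profile.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("model.onnx", "rb") as f:
    parser.parse(f.read())

config = builder.create_builder_config()
profile = builder.create_optimization_profile()
# (batch, seq_len): TRT specializes kernels for the opt shape and the engine
# accepts anything between min and max at runtime
profile.set_shape("input_ids", (1, 1), (8, 128), (32, 512))
config.add_optimization_profile(profile)
serialized_engine = builder.build_serialized_network(network, config)
```

CUDA-graph capture would then be done at a fixed (e.g. max) shape with fixed input addresses, as suggested above.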
A key component of NJT support for SDPA is block-diagonal attention masks. Does TRT have support/examples for block-diagonal attn masks?
This is because one would want proper FlashAttention kernels in this setup; otherwise the speedups likely won't be realized...
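For concreteness, a toy PyTorch sketch of what such a block-diagonal mask looks like for a packed (unpadded) batch; the helper name and the lengths are made up for illustration:

```python
import torch

def block_diagonal_attn_mask(seq_lens):
    # For a packed batch where sequences of the given lengths are concatenated
    # along the token dimension, allow attention only within each sequence.
    seq_id = torch.repeat_interleave(torch.arange(len(seq_lens)), torch.tensor(seq_lens))
    return seq_id[:, None] == seq_id[None, :]  # (total_tokens, total_tokens), bool

mask = block_diagonal_attn_mask([3, 2, 4])  # 9x9 mask with three diagonal blocks
```

Materializing and applying such a dense mask naively still costs O(total_tokens²), which is exactly why fused varlen FlashAttention-style kernels that skip the masked blocks matter here.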
@zhenhuaw-me Can you take a look?
The relevant documentation in Triton Inference Server on ragged batch support: https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/ragged_batching.html
So it would be good to have end-to-end examples (starting from a PyTorch model, then exporting and building the TRT engine file, with trex visualizations) of optimized attention modules for transformer inference on such varlen sequences in TRT...
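As a rough sketch of the first step of such an end-to-end example (toy model, made-up names): export to ONNX with dynamic batch/sequence axes, then build the TRT engine with shape profiles as in the snippet above.

```python
import torch
import torch.nn as nn

# Toy encoder standing in for the real model; names and sizes are placeholders.
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True), num_layers=2
).eval()

example = torch.randn(2, 16, 64)  # (batch, seq, hidden)
torch.onnx.export(
    encoder, (example,), "encoder.onnx",
    input_names=["tokens"], output_names=["hidden"],
    dynamic_axes={"tokens": {0: "batch", 1: "seq"}, "hidden": {0: "batch", 1: "seq"}},
)
```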
These kernels appear to be available in the older FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/docs/bert_guide.md#model-architecture or in https://github.com/bytedance/effective_transformer
It would be good to upstream these "EffectiveTransformer" kernels into TensorRT, given that FasterTransformer has been EOL'd.
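For context, the core trick in EffectiveTransformer is packing: drop the padding tokens before the transformer layers and scatter the results back afterwards. A rough PyTorch illustration (function names made up, not from either repo):

```python
import torch

def remove_padding(x, attention_mask):
    # x: (batch, max_len, hidden); attention_mask: (batch, max_len), 1 for real tokens
    keep = attention_mask.bool().flatten()   # (batch * max_len,)
    packed = x.flatten(0, 1)[keep]           # (total_real_tokens, hidden)
    return packed, keep

def restore_padding(packed, keep, batch, max_len):
    out = packed.new_zeros(batch * max_len, packed.shape[-1])
    out[keep] = packed                       # scatter real tokens back into padded layout
    return out.view(batch, max_len, -1)

x = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
packed, keep = remove_padding(x, mask)       # 5 rows instead of 8
restored = restore_padding(packed, keep, 2, 4)
```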
PyTorch now has some support for representing varlen sequences (nested tensors with the jagged layout, a.k.a. NJT). It is supported to some extent by HF:
This is useful e.g. for saving compute on padding tokens in BERT inference. Does TRT have kernels for such NJT SDPA ops? (and can they be executed via CUDA graphs?) If so, how can one benefit from them? Is there an example?
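For reference, a minimal sketch of the PyTorch side of this, following the torch.nested jagged-layout pattern; the lengths, head count and head dim are arbitrary, and the fused SDPA kernels expect a CUDA device and fp16/bf16:

```python
import torch
import torch.nn.functional as F

lengths, n_heads, head_dim = [3, 5, 2], 4, 16
device, dtype = "cuda", torch.bfloat16

def rand_njt():
    return torch.nested.nested_tensor(
        [torch.randn(l, n_heads, head_dim, device=device, dtype=dtype) for l in lengths],
        layout=torch.jagged,
    )

q, k, v = rand_njt(), rand_njt(), rand_njt()
# SDPA wants (batch, heads, seq, head_dim), so swap the ragged dim behind heads
out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
```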
Thank you!