NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

State of affairs for NestedTensor (NJT) inference? #4234

Open vadimkantorov opened 3 weeks ago

vadimkantorov commented 3 weeks ago

PyTorch now has some support for representing varlen sequences (NestedTensor, a.k.a. NJT), and it is supported to some extent by HF.

This is useful e.g. for saving compute on padding tokens in BERT inference. Does TRT have kernels for such NJT SDPA ops (and can they be executed via CUDA graphs)? If so, how can one benefit from them? Is there an example?
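
For concreteness, a minimal sketch of what I mean on the PyTorch side (assuming a recent PyTorch with jagged-layout NestedTensor and a CUDA device so SDPA can pick a fused varlen backend; the sizes are placeholders):

```python
import torch
import torch.nn.functional as F

n_heads, head_dim = 4, 16

# Variable-length sequences, already projected to (L_i, n_heads, head_dim)
seqs = [torch.randn(L, n_heads, head_dim, device="cuda", dtype=torch.float16)
        for L in (5, 9, 3)]

# Jagged-layout NestedTensor (NJT): no padding tokens are materialized
njt = torch.nested.nested_tensor(seqs, layout=torch.jagged)  # (B, j1, H, D)
q = k = v = njt.transpose(1, 2)                              # (B, H, j1, D)

# SDPA on NJT dispatches to varlen attention kernels, which apply the
# per-sequence (block-diagonal) masking implicitly instead of using padding
out = F.scaled_dot_product_attention(q, k, v)                # (B, H, j1, D)
```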

Thank you!

lix19937 commented 2 weeks ago

In my opinion, TRT has no such op (NJT); you can write a custom one.

If you want to use CUDA graphs, you need to set a max sequence length (use fixed addresses for the sequence buffers to build the static graph) and set the min, opt, and max shapes for this input.
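
To illustrate the min/opt/max part, a rough sketch with the TensorRT Python builder API (the input name "input_ids" and the shapes are placeholders; the network itself would come from e.g. the ONNX parser):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
config = builder.create_builder_config()

# ... populate `network`, e.g. via trt.OnnxParser, with a dynamic
# sequence dimension on the "input_ids" input of shape (batch, seq_len) ...

# Optimization profile: min / opt / max shapes for the dynamic dimensions.
# The max sequence length also bounds the fixed-size buffers that must stay
# at stable addresses when capturing a CUDA graph around execution.
profile = builder.create_optimization_profile()
profile.set_shape("input_ids", min=(1, 1), opt=(8, 128), max=(16, 512))
config.add_optimization_profile(profile)

engine_bytes = builder.build_serialized_network(network, config)
```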

vadimkantorov commented 2 weeks ago

A key component of NJT support for SDPA is block-diagonal masks. Does TRT have support/examples for block-diagonal attention masks?

One would want proper FlashAttention kernels in this setup; otherwise the speedups likely won't be realized...
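
To make the request concrete, this is the mask I mean, sketched for packed (padding-free) sequences in plain PyTorch; fused varlen/FlashAttention kernels apply it implicitly rather than materializing the O(T^2) tensor:

```python
import torch
import torch.nn.functional as F

# Three sequences of lengths 5, 9, 3 packed along a single token dimension
seq_lens = torch.tensor([5, 9, 3])
total = int(seq_lens.sum())

# seq_id[t] = index of the original sequence that packed token t belongs to
seq_id = torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens)

# Block-diagonal mask: a query may only attend to keys from the same sequence
block_diag = seq_id[:, None] == seq_id[None, :]            # (total, total)

# Reference (non-fused) SDPA over the packed tokens using the explicit mask
q = k = v = torch.randn(1, 1, total, 64)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=block_diag)
```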

poweiw commented 2 weeks ago

@zhenhuaw-me Can you take a look?

vadimkantorov commented 2 weeks ago

The relevant documentation in Triton Inference Server on ragged batch support: https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/ragged_batching.html

So it would be good to have end-to-end examples (starting from a PyTorch model, through export, to configuring the TRT engine file, with trex visualizations) of optimized attention modules for transformer inference on such varlen sequences in TRT...
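
As a starting point for such an end-to-end example, the PyTorch-to-ONNX leg with a dynamic sequence dimension might look roughly like this (padded inputs here, since as far as I know NJT itself cannot be exported to ONNX yet; the tiny attention module, input name, and shapes are placeholders):

```python
import torch
import torch.nn.functional as F

class TinySelfAttention(torch.nn.Module):
    """Placeholder standing in for a real BERT-style attention block."""
    def __init__(self, dim=256, heads=4):
        super().__init__()
        self.heads = heads
        self.qkv = torch.nn.Linear(dim, 3 * dim)
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):                                   # x: (batch, seq_len, dim)
        B, L, D = x.shape
        q, k, v = self.qkv(x).view(B, L, 3, self.heads, D // self.heads).permute(2, 0, 3, 1, 4)
        out = F.scaled_dot_product_attention(q, k, v)       # (B, H, L, D/H)
        return self.proj(out.transpose(1, 2).reshape(B, L, D))

model, dummy = TinySelfAttention().eval(), torch.randn(2, 128, 256)

torch.onnx.export(
    model, (dummy,), "attention.onnx",
    input_names=["tokens"], output_names=["out"],
    dynamic_axes={"tokens": {0: "batch", 1: "seq_len"},
                  "out": {0: "batch", 1: "seq_len"}},
)

# The TRT engine would then be built with dynamic-shape profiles, e.g.:
#   trtexec --onnx=attention.onnx \
#           --minShapes=tokens:1x1x256 --optShapes=tokens:8x128x256 --maxShapes=tokens:16x512x256
```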

vadimkantorov commented 2 weeks ago

These kernels appear to be available in the older FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/docs/bert_guide.md#model-architecture or in https://github.com/bytedance/effective_transformer

It would be good to upstream these "EffectiveTransformer kernels with TensorRT" given that FasterTransformer has been EOL'd.
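
For reference, the core remove-padding / restore-padding trick those kernels implement, sketched in plain PyTorch (the fused kernels do the same gather/scatter plus varlen attention without round-tripping through the padded layout):

```python
import torch

hidden = 64
lengths = torch.tensor([5, 9, 3])
B, L_max = len(lengths), int(lengths.max())

padded = torch.randn(B, L_max, hidden)
valid = torch.arange(L_max)[None, :] < lengths[:, None]   # (B, L_max) mask of real tokens

# "Remove padding": gather only the real tokens into a packed tensor
packed = padded[valid]                                     # (lengths.sum(), hidden)

# ... run the transformer stack on `packed` with block-diagonal / varlen attention ...

# "Restore padding": scatter the packed tokens back into the padded layout
restored = padded.new_zeros(B, L_max, hidden)
restored[valid] = packed
```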