NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[C/PyTorch] Simplify THD offset tensors #927

Closed (cyanguwa closed this 2 weeks ago)

cyanguwa commented 3 weeks ago

Description

This PR reduces the THD offset tensors from four (seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o) to two (cu_seqlens_q_padded, cu_seqlens_kv_padded).

Before this PR, for the THD_THD_THD layout, users needed to calculate these four tensors:

seq_offsets_q = config.num_heads      * config.head_dim * cu_seqlens_q_padded
seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_padded
seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_padded
seq_offsets_o = config.num_heads      * config.head_dim * cu_seqlens_q_padded

With this PR, users only need to provide two tensors, cu_seqlens_q_padded and cu_seqlens_kv_padded, which are easier to understand and to use correctly.
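Each of the four legacy tensors is one of the two padded cu_seqlens tensors scaled by a per-token element stride (num_heads * head_dim for Q/O, num_gqa_groups * head_dim for K/V), which is why two tensors carry the same information. A minimal pure-Python sketch, with lists standing in for the integer tensors; `legacy_thd_offsets` is a hypothetical helper for illustration, not a TransformerEngine API, and the head counts are arbitrary:

```python
from typing import Dict, List


def legacy_thd_offsets(
    cu_seqlens_q_padded: List[int],
    cu_seqlens_kv_padded: List[int],
    num_heads: int,
    num_gqa_groups: int,
    head_dim: int,
) -> Dict[str, List[int]]:
    """Rebuild the four pre-PR offset tensors from the two padded cu_seqlens.

    Hypothetical helper: offsets are element counts into a THD_THD_THD
    buffer, following the seq_offsets_* formulas in the PR description.
    """
    q_stride = num_heads * head_dim        # elements per token in Q and O
    kv_stride = num_gqa_groups * head_dim  # elements per token in K and V
    return {
        "seq_offsets_q": [q_stride * x for x in cu_seqlens_q_padded],
        "seq_offsets_k": [kv_stride * x for x in cu_seqlens_kv_padded],
        "seq_offsets_v": [kv_stride * x for x in cu_seqlens_kv_padded],
        "seq_offsets_o": [q_stride * x for x in cu_seqlens_q_padded],
    }


offsets = legacy_thd_offsets([0, 2, 4, 7, 9], [0, 2, 4, 7, 9],
                             num_heads=16, num_gqa_groups=4, head_dim=64)
print(offsets["seq_offsets_q"])  # [0, 2048, 4096, 7168, 9216]
print(offsets["seq_offsets_k"])  # [0, 512, 1024, 1792, 2304]
```

Since K and V (and likewise Q and O) always shared the same offsets, passing them separately was redundant.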

To illustrate the difference between cu_seqlens and cu_seqlens_padded: the batch [a, PAD, b, b, c, PAD, PAD, d, d] contains 4 sequences with actual lengths [1, 2, 1, 2] and padded lengths [2, 2, 3, 2], so cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].
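The example above can be reproduced with two prefix sums: one over the actual lengths and one over the padded lengths. A minimal pure-Python sketch, assuming each sequence's padding directly follows its real tokens (as in the example batch); `cu_seqlens_from_segments` is a hypothetical helper, not a TransformerEngine function:

```python
from itertools import accumulate
from typing import List, Tuple

PAD = "PAD"


def cu_seqlens_from_segments(segments: List[List[str]]) -> Tuple[List[int], List[int]]:
    """Return (cu_seqlens, cu_seqlens_padded) for sequences packed back to back.

    Assumption: each segment is one sequence's tokens followed by its padding.
    """
    seqlens = [sum(tok != PAD for tok in seg) for seg in segments]  # real tokens only
    padded_lens = [len(seg) for seg in segments]                    # real tokens + padding
    cu_seqlens = list(accumulate(seqlens, initial=0))
    cu_seqlens_padded = list(accumulate(padded_lens, initial=0))
    return cu_seqlens, cu_seqlens_padded


segments = [["a", PAD], ["b", "b"], ["c", PAD, PAD], ["d", "d"]]
cu_seqlens, cu_seqlens_padded = cu_seqlens_from_segments(segments)
print(cu_seqlens)         # [0, 1, 3, 4, 6]
print(cu_seqlens_padded)  # [0, 2, 4, 7, 9]

# Sequence i starts at cu_seqlens_padded[i] in the packed batch and has
# cu_seqlens[i + 1] - cu_seqlens[i] valid tokens.
packed = [tok for seg in segments for tok in seg]
i = 2
start = cu_seqlens_padded[i]
length = cu_seqlens[i + 1] - cu_seqlens[i]
print(packed[start:start + length])  # ['c']
```

cu_seqlens_padded locates where each sequence begins in memory, while cu_seqlens still determines how many of its tokens are valid.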


cyanguwa commented 2 weeks ago

/te-ci pytorch

cyanguwa commented 2 weeks ago

/te-ci jax

cyanguwa commented 2 weeks ago

/te-ci paddle

zlsh80826 commented 2 weeks ago

Hi @cyanguwa, I remember that we have 3 pending API changes:

  1. Support separate q/kv actual_seqlen and offsets for the qkvpacked API
  2. Support seqlens to avoid the cu_seqlens -> seqlens -> cu_seqlens round trip
  3. Simplify THD format APIs (this PR)

Do you have an estimated timeline for items 1 and 2? Should we also make those changes in this PR?
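For reference, the round trip mentioned in item 2 is a pair of conversions between cumulative and per-sequence lengths. A minimal pure-Python sketch, with lists standing in for tensors; both helper names are hypothetical, not TransformerEngine functions:

```python
from itertools import accumulate
from typing import List


def cu_seqlens_to_seqlens(cu_seqlens: List[int]) -> List[int]:
    # Adjacent differences recover each sequence's length.
    return [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]


def seqlens_to_cu_seqlens(seqlens: List[int]) -> List[int]:
    # A prefix sum with a leading zero rebuilds the cumulative form.
    return list(accumulate(seqlens, initial=0))


cu = [0, 1, 3, 4, 6]
seqlens = cu_seqlens_to_seqlens(cu)
print(seqlens)                         # [1, 2, 1, 2]
print(seqlens_to_cu_seqlens(seqlens))  # [0, 1, 3, 4, 6]
```

Accepting seqlens directly would let a caller that already holds per-sequence lengths skip this back-and-forth.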

cyanguwa commented 2 weeks ago

Hi @zlsh80826,

Yes, this PR focuses only on item 3. I wanted to get this done first so that there is no API change between v1.8 and v1.9. I'm still evaluating the benefits and changes for items 1 and 2, but they will cause breaking API changes anyway. There is less urgency for those (at least compared with item 3, given the upcoming code freeze).

Thanks for reviewing.