Closed cyanguwa closed 2 weeks ago
/te-ci pytorch
/te-ci jax
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
/te-ci jax
/te-ci paddle
/te-ci paddle
Hi @cyanguwa, I recall that we have 3 API changes pending.
Do you have an estimated timeline for items 1 and 2? Should we also make those changes in this PR?
Hi @zlsh80826 ,
Yes, this PR is focused only on item 3. I wanted to get this done first so there is no API change between v1.8 and v1.9. I'm still evaluating the benefits/changes for items 1 and 2, but they will cause breaking API changes anyway. There is no urgency to them (well, not as much as item 3, given the code freeze coming up).
Thanks for reviewing.
Description
This PR reduces the THD offset tensors from four (seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o) to two (cu_seqlens_q_padded, cu_seqlens_kv_padded).

Before this PR, users of the THD_THD_THD layout had to calculate all four of these tensors themselves.

With this PR, users only need to provide two tensors, cu_seqlens_q_padded and cu_seqlens_kv_padded, which are easier to understand and use correctly.

As an example of the difference between cu_seqlens and cu_seqlens_padded: for a batch [a, PAD, b, b, c, PAD, PAD, d, d], we have 4 sequences, cu_seqlens = [0, 1, 3, 4, 6], and cu_seqlens_padded = [0, 2, 4, 7, 9].
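The relationship between the two tensors in this example can be sketched in plain Python. This is only an illustration, not the Transformer Engine implementation: the helper name cumulative_offsets and the per-sequence length lists are hypothetical, and it assumes the padding tokens trail each sequence, as in the batch above.

```python
from itertools import accumulate

# Example batch from the description: 4 sequences, each followed by its padding.
batch = ["a", "PAD", "b", "b", "c", "PAD", "PAD", "d", "d"]

# Per-sequence lengths (hypothetical inputs, read off the batch above).
seqlens = [1, 2, 1, 2]          # valid tokens only
seqlens_padded = [2, 2, 3, 2]   # valid tokens plus trailing padding


def cumulative_offsets(lengths):
    """Return the cumulative sum of lengths with a leading 0."""
    return [0] + list(accumulate(lengths))


cu_seqlens = cumulative_offsets(seqlens)
cu_seqlens_padded = cumulative_offsets(seqlens_padded)

print(cu_seqlens)         # [0, 1, 3, 4, 6]
print(cu_seqlens_padded)  # [0, 2, 4, 7, 9]

# Sequence i starts at cu_seqlens_padded[i] in the packed batch and
# has seqlens[i] valid tokens; the rest up to the next offset is padding.
valid_tokens = [
    batch[cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlens[i]]
    for i in range(len(seqlens))
]
print(valid_tokens)  # [['a'], ['b', 'b'], ['c'], ['d', 'd']]
```

In this sketch, cu_seqlens_padded gives where each sequence starts in memory, while consecutive differences of cu_seqlens give how many tokens in each sequence are valid.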
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: