RulinShao / LightSeq

Official repository for LightSeq: Sequence Level Parallelism for Distributed Training of Long Context Transformers

varlen dist_attn kernel #2

Closed RulinShao closed 10 months ago

RulinShao commented 10 months ago

Add padding-mask support to the LightSeq kernel (draft; tested only on simple cases). Inputs are unpadded before being fed into the kernel. Currently, the unpadded length must be a multiple of the flash attn block size; support for arbitrary unpadded lengths will come in the next update.
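For readers unfamiliar with the unpadding step, here is a minimal sketch of what it might look like. It is not the code in this PR: the `unpad` helper, its arguments, and the `cu_seqlens` return layout are illustrative assumptions borrowed from the common variable-length attention convention, and plain Python lists stand in for tensors to keep the example dependency-free. The `block_size` check assumes the length restriction applies to each sequence individually, which the description above does not spell out.

```python
from itertools import accumulate


def unpad(tokens, mask, block_size=None):
    """Pack a padded batch into one flat sequence (hypothetical helper).

    tokens: list of rows, each padded to the same length
    mask:   same shape as tokens; 1 marks a real token, 0 marks padding
    block_size: if given, every unpadded length must be a multiple of it
                (assumed per-sequence interpretation of the restriction)

    Returns (packed, cu_seqlens, max_seqlen), where cu_seqlens[i] is the
    offset of sequence i in `packed` and cu_seqlens[-1] is the total length.
    """
    seqlens = [sum(row) for row in mask]

    if block_size is not None:
        for i, n in enumerate(seqlens):
            if n % block_size != 0:
                raise ValueError(
                    f"sequence {i} has unpadded length {n}, "
                    f"not a multiple of block size {block_size}"
                )

    # Keep only the positions where the mask is set, in batch order.
    packed = [
        tok
        for tok_row, mask_row in zip(tokens, mask)
        for tok, keep in zip(tok_row, mask_row)
        if keep
    ]
    cu_seqlens = [0] + list(accumulate(seqlens))
    return packed, cu_seqlens, max(seqlens, default=0)


if __name__ == "__main__":
    tokens = [[1, 2, 3, 0], [4, 5, 0, 0]]
    mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
    print(unpad(tokens, mask))  # ([1, 2, 3, 4, 5], [0, 3, 5], 3)
```

Passing `block_size=2` with the batch above raises `ValueError`, because the first sequence has unpadded length 3. That mirrors the current restriction, under the per-sequence assumption stated above.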