pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

In PyTorch v2.1.2, within fmha_api.cpp, the mha_fwd() function redundantly checks cu_seqlens_k #122595

Open 1274085042 opened 6 months ago

1274085042 commented 6 months ago

🐛 Describe the bug

In `mha_fwd()`, the contiguity of `cu_seqlens_k` is checked twice on consecutive lines, so the corresponding check for `cu_seqlens_q` is effectively missing:

https://github.com/pytorch/pytorch/blob/v2.1.2/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp#L247-L248
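
For reference, the linked lines amount to the following (a paraphrase, not a verbatim copy; the exact check macro and any error message may differ in the source):

```cpp
// aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp, v2.1.2
// (paraphrased): both lines check cu_seqlens_k, leaving cu_seqlens_q
// unvalidated.
TORCH_CHECK(cu_seqlens_k.is_contiguous());
TORCH_CHECK(cu_seqlens_k.is_contiguous());
```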
cc @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki

Versions

torch version == v2.1.2

drisspg commented 6 months ago

If you wanted to update this to check that `cu_seqlens_q` is contiguous, that would be sweet and I can review.
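
A minimal sketch of the suggested fix, assuming the surrounding code uses `TORCH_CHECK`-style validation as elsewhere in ATen (the helper function here is hypothetical, added only to make the snippet self-contained):

```cpp
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Hypothetical helper sketching the fix: validate each cumulative
// sequence-length tensor exactly once, instead of checking cu_seqlens_k twice.
static void check_cu_seqlens_contiguous(const at::Tensor& cu_seqlens_q,
                                        const at::Tensor& cu_seqlens_k) {
  TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
  TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
}
```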