Quantized Attention that achieves speedups of 2.1x and 2.7x compared to FlashAttention2 and xformers, respectively, without losing end-to-end metrics across various models.
q_kernel_per_block_int8 error in distributed settings. #25
I tried to apply SageAttention in DeepSpeed-Ulysses, but it fails with the following error in q_kernel_per_block_int8.
Could you please provide a torch version of q_kernel_per_block_int8, so I can debug my program?
rank1: q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k)
rank1: File "/cfs/fjr2/SageAttention/sageattention/quant_per_block.py", line 63, in per_block_int8
rank1: File "/home/pjz/miniconda3/envs/fjr/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in rank1: return lambda *args, kwargs: self.run(grid=grid, warmup=False, *args, *kwargs)
rank1: File "/home/pjz/miniconda3/envs/fjr/lib/python3.10/site-packages/triton/runtime/jit.py", line 691, in run
rank1: kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
rank1: File "/home/pjz/miniconda3/envs/fjr/lib/python3.10/site-packages/triton/backends/nvidia/driver.py", line 365, in callrank1: self.launch(args, kwargs)
rank1: ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
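For reference, here is a minimal pure-PyTorch sketch of per-block INT8 quantization that can stand in for the Triton kernel while debugging. It assumes a (batch, heads, seq_len, head_dim) layout, 128-token blocks, and symmetric absmax/127 scaling; the actual per_block_int8 Triton kernel may differ (e.g. in the block sizes used for q and k, or in how k is preprocessed).

```python
# Hedged pure-torch stand-in for per-block INT8 quantization (debugging only).
# Assumptions: layout (batch, heads, seq_len, head_dim), block size 128 tokens,
# one symmetric scale per (batch, head, block) with scale = absmax / 127.
import torch
import torch.nn.functional as F

def per_block_int8_torch(x: torch.Tensor, blk: int = 128):
    b, h, n, d = x.shape
    pad = (blk - n % blk) % blk
    if pad:
        x = F.pad(x, (0, 0, 0, pad))              # pad the sequence dimension to a multiple of blk
    xb = x.reshape(b, h, -1, blk, d).float()      # (b, h, n_blocks, blk, d); fp32 for stable scales
    scale = xb.abs().amax(dim=(-2, -1), keepdim=True) / 127.0
    scale = scale.clamp(min=1e-8)                 # guard against all-zero blocks
    x_int8 = (xb / scale).round().clamp(-127, 127).to(torch.int8)
    x_int8 = x_int8.reshape(b, h, -1, d)[:, :, :n, :]   # drop the padding again
    return x_int8, scale.reshape(b, h, -1)

# Usage for debugging: quantize q and k separately and compare against the Triton path.
# q_int8, q_scale = per_block_int8_torch(q)
# k_int8, k_scale = per_block_int8_torch(k)
```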
https://github.com/feifeibear/long-context-attention/pull/90
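The ValueError at the bottom of the traceback usually means the tensor handed to the Triton kernel is not a CUDA tensor on the active device, which is easy to hit when each Ulysses rank has to select its own GPU. Below is a hedged sanity-check wrapper; the sageattn call signature and the local-rank derivation are assumptions, not part of the reported setup.

```python
# Hedged sanity check before calling SageAttention in a multi-GPU (Ulysses) run.
# Assumption: one process per GPU on a single node; adapt local_rank to your launcher.
import torch
import torch.distributed as dist
from sageattention import sageattn  # assumed public entry point

def checked_sageattn(q, k, v, is_causal=False):
    local_rank = dist.get_rank() % torch.cuda.device_count() if dist.is_initialized() else 0
    torch.cuda.set_device(local_rank)             # make this rank's GPU the active device
    for name, t in (("q", q), ("k", k), ("v", v)):
        assert t.is_cuda, f"{name} is on {t.device}; Triton kernels need CUDA tensors"
    return sageattn(q, k, v, is_causal=is_causal)
```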