facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

A question about the mat_mult_with_mask kernel function #733

Open shenshanf opened 1 year ago

shenshanf commented 1 year ago

In the mat_mult_with_mask kernel implemented in your project, you use the CUDA_1D_KERNEL_LOOP macro to create a loop, which may cause threads with different ids to access and modify the same non-zero element at the same time. I am confused about whether this loop causes performance degradation. Would it be better to replace CUDA_1D_KERNEL_LOOP directly with if (tid < nnz)?

danthe3rd commented 1 year ago

Hi, where is this mat_mult_with_mask function you are mentioning?

shenshanf commented 1 year ago

Sorry, I didn't describe the problem clearly.

In this file: matmul.cu, line 33

You create a for loop via the macro CUDA_1D_KERNEL_LOOP over every thread id, i.e. id = (blockIdx.x * blockDim.x) + threadIdx.x. It seems this may cause different thread ids to access and modify the same non-zero element in the sparse mask at the same time, causing performance degradation. Would it be better to replace CUDA_1D_KERNEL_LOOP directly with if (id < nnz)?
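For concreteness, here is a minimal sketch of the two indexing patterns being compared, assuming the common Caffe2-style definition of the macro (the kernel bodies are placeholders, not the xformers code):

```cuda
// Grid-stride loop macro (assumed Caffe2-style definition).
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Pattern 1: grid-stride loop. Each thread visits i, i + stride, ...
// so any nnz is covered regardless of the launch size.
__global__ void with_loop(float* out, int nnz) {
  CUDA_1D_KERNEL_LOOP(i, nnz) {
    out[i] = 0.f;  // placeholder body
  }
}

// Pattern 2: plain bounds check. One element per thread; correct only
// if the launch spawns at least nnz threads.
__global__ void with_guard(float* out, int nnz) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < nnz) {
    out[i] = 0.f;  // placeholder body
  }
}
```

Note that in the grid-stride form the stride is blockDim.x * gridDim.x, the total number of launched threads, so no two threads ever map to the same index i.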

danthe3rd commented 1 year ago

I'm not sure about this. This is code we haven't touched or used in a while, so it's very possible that it could be optimized further. Cc @fmassa, who wrote this code.

Out of curiosity, what do you need this code for?

shenshanf commented 1 year ago

Thanks, I understand the code you wrote now: when nnz exceeds the total number of launched threads, some threads handle multiple non-zero elements.
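As a worked example (hypothetical launch sizes), the macro expands to a loop like this:

```cuda
__global__ void visit_nnz(float* out, int nnz) {
  // With a hypothetical launch of 64 blocks * 256 threads, the stride is
  // 64 * 256 = 16384, so for nnz = 40000 thread 0 visits
  // i = 0, 16384, 32768 and then exits the loop.
  int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nnz; i += stride) {
    out[i] = (float)i;  // each index i is written by exactly one thread
  }
}
```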

shenshanf commented 1 year ago

I need this code to implement a graph attention network algorithm with sparse attention.

shenshanf commented 1 year ago

I have implemented a matmul_with_mask function based on your code that can compute sparse gradients; it is attached. I implemented it using Numba CUDA. matmul_mask.zip
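For readers following along, here is a minimal CUDA sketch of the mask-driven forward product under discussion; the names, COO mask layout, and signature are assumptions for illustration, not the attached Numba implementation:

```cuda
// Grid-stride loop macro (assumed Caffe2-style definition).
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// For each non-zero (r, c) of the mask, compute dot(A[r, :], B[:, c]).
// A is M x K row-major, B is K x N row-major; out holds one value per
// non-zero. row_idx/col_idx are the assumed COO coordinates of the mask.
__global__ void matmul_with_mask(const float* A, const float* B,
                                 const int* row_idx, const int* col_idx,
                                 float* out, int K, int N, int nnz) {
  CUDA_1D_KERNEL_LOOP(i, nnz) {
    int r = row_idx[i];
    int c = col_idx[i];
    float acc = 0.f;
    for (int k = 0; k < K; ++k) {
      acc += A[r * K + k] * B[k * N + c];
    }
    out[i] = acc;  // each non-zero is written by exactly one thread
  }
}
```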