leejet / stable-diffusion.cpp

Stable Diffusion and Flux in pure C/C++

flash attention leads to error #74

Closed: rayrayraykk closed this issue 10 months ago

rayrayraykk commented 11 months ago

I tried to use ggml_flash_attn to accelerate the process, so I replaced the ggml_mul_mat-based attention in the UNet cross-attention in stable-diffusion.cpp:

...
#if 1
                // fused path: flash attention computes softmax(q·kᵀ)·v in a single op
                struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, true);
#else
                // reference path: explicit matmul + softmax + matmul
                struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);  // [N * n_head, h * w, h * w]
                // kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
                kq = ggml_soft_max_inplace(ctx, kq);

                struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq);  // [N * n_head, h * w, d_head]
#endif
...

But it leads to an error. It looks like max_position = 2, N = 64, and const int64_t P = nek1 - N; ends up negative. Can someone help me? Thanks a lot!
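For context, here is a minimal sketch of where that assertion appears to come from, assuming it is the P = nek1 - N guard in ggml's flash-attention code and that P plays the role of the number of "past" key positions used by the causal mask (this is a sketch following the names in the error above, not the literal ggml source):

    // sketch of the shape guard in the flash-attention path
    const int64_t P = nek1 - N;  // N = number of query positions (64 here),
                                 // nek1 = number of key positions
    GGML_ASSERT(P >= 0);         // P is treated as "past" keys for the causal
                                 // mask, so the kernel requires at least as
                                 // many key positions as query positions

In other words, the check assumes a layout where the key sequence is at least as long as the query sequence (the usual causal self-attention case), and this attention call does not satisfy that.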

rayrayraykk commented 11 months ago

@ggerganov @Green-Sky @leejet Looking forward to your help! :)

rayrayraykk commented 10 months ago

After I changed the call to ggml_flash_attn(ctx, q, k, v, false); and added:

    if (masked) {
        GGML_ASSERT(P >= 0);  // only enforce the causal-mask size check when masking is requested
    }

The program runs fine, but the images it generates make absolutely no sense... I'm really confused :(

Edit: ggml_flash_attn already scales kq internally; after commenting out the line below, everything works fine.

q = ggml_scale_inplace(ctx, q, ggml_new_f32(ctx, 1.0f / sqrt((float)d_head)));
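For clarity, a minimal before/after sketch of this fix, assuming ggml_flash_attn already applies the 1 / sqrt(d_head) softmax scaling internally, which is why the extra caller-side scale double-scales the attention scores (the USE_FLASH_ATTN switch is only for illustration):

    #if USE_FLASH_ATTN
        // fused path: do NOT pre-scale q here; the kernel already scales the
        // q·kᵀ scores by 1 / sqrt(d_head), so scaling again applies the factor
        // twice and produces the garbled images described above
        struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false);
    #else
        // manual path: the caller applies the 1 / sqrt(d_head) scale exactly once
        q = ggml_scale_inplace(ctx, q, ggml_new_f32(ctx, 1.0f / sqrt((float)d_head)));
        struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);
        kq = ggml_soft_max_inplace(ctx, kq);
        struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq);
    #endif

With the pre-scale removed, both paths apply the 1 / sqrt(d_head) factor exactly once, which matches the observation above.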
Green-Sky commented 10 months ago

Did you observe any speed improvement?

rayrayraykk commented 10 months ago

> Did you observe any speed improvement?

There is a speed improvement with the CLBlast backend, but on the other backends it is not noticeable.

FSSRepo commented 10 months ago

@rayrayraykk Once I finish my pull request adding the CUDA backend, I will see about working on flash attention v2 and improving the conv2d algorithm with an FFT. Additionally, there is a recent paper that uses tensor cores to accelerate convolutions in CUDA. The truth is that I don't have enough knowledge to translate the equations from papers into code, which makes it somewhat difficult for me to implement these things.