mindspore-lab / mindone

one for all, Optimal generator with No Exception
https://mindspore-lab.github.io/mindone/
Apache License 2.0

FlashAttention: Fix Compatibility #295

Closed: wtomin closed this pull request 5 months ago

wtomin commented 5 months ago

This is another solution to the compatibility problem of the flash-attention layer. It may be better than #291 because nn.layers.flash_attention supports a head_dim that is not divisible by 16 by padding it to 16*N.
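Why padding head_dim is harmless can be checked with a minimal NumPy sketch. It assumes the layer zero-pads q, k and v along the last axis up to 16*N and slices the output back. This mirrors the idea only, not the actual nn.layers.flash_attention code, and the function names are hypothetical. Zero columns contribute nothing to q·kᵀ, and the extra zero columns of v only produce extra output columns that are dropped afterwards. The softmax scale must still use the original head_dim.

```python
import numpy as np

def pad_head_dim(x, multiple=16):
    """Zero-pad the last axis (head_dim) of x up to the next multiple of `multiple`."""
    d = x.shape[-1]
    padded = -(-d // multiple) * multiple  # ceil(d / multiple) * multiple
    pad_width = [(0, 0)] * (x.ndim - 1) + [(0, padded - d)]
    return np.pad(x, pad_width)  # default mode pads with zeros

def attention(q, k, v, scale):
    """Vanilla softmax attention over the last two axes."""
    s = q @ np.swapaxes(k, -1, -2) * scale
    s = s - s.max(axis=-1, keepdims=True)  # numerical stability
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

def padded_attention(q, k, v, multiple=16):
    """Hypothetical padded path: pad q/k/v, attend, then slice the output back."""
    d = q.shape[-1]
    scale = 1.0 / np.sqrt(d)  # keep the scale of the ORIGINAL head_dim
    out = attention(pad_head_dim(q, multiple),
                    pad_head_dim(k, multiple),
                    pad_head_dim(v, multiple), scale)
    return out[..., :d]  # drop the zero columns produced by the padded v

rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 2, 8, 10, 40))  # (batch, heads, seq, head_dim=40)
print(pad_head_dim(q).shape, np.allclose(padded_attention(q, k, v), attention(q, k, v, 1.0 / np.sqrt(40))))
```

With head_dim=40 the padded tensors have head_dim 48, and the sliced output matches unpadded attention up to floating-point error.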

I have tested text_to_image.py on:

  1. 910A (ms2.1.0)
     1.1 sdv2.0: ✔ graph mode, ✔ pynative mode
     1.2 sdv1.5: ✔ graph mode, ✔ pynative mode
     sdv1.5 succeeds only because I restrict fa_max_head_dim to 128 instead of 256 on 910A; otherwise there is an OOM error. This value can be passed through v1-inference.yaml.

  2. 910B (ms 2.2.10.2023.1124)
     2.1 sdv1.5: ✔ graph mode, ✔ pynative mode
     2.2 sdv2.0: ✔ graph mode, ❌ pynative mode, which fails with:

out = self.flash_attention(
    q.to(ms.float16), k.to(ms.float16), v.to(ms.float16), mask.to(self.fa_mask_dtype)
)
...
Runtime Error: Index 3 is out of bounds 3.

I think it is a weird bug, because graph mode does not raise this error.
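The effect of the fa_max_head_dim cap in the 910A results can be sketched roughly. This sketch assumes FA is used only for layers whose head_dim is at most fa_max_head_dim, with the remaining layers falling back to vanilla attention. That rule and the helper names are assumptions, not taken from the mindone code. Head settings follow the public SD configs: 8 heads at every width for sdv1.5, and a fixed 64 channels per head for sdv2.0.

```python
# Hypothetical illustration: which attention layers would use FA under a head-dim cap.
ATTN_WIDTHS = (320, 640, 1280)  # UNet channel widths that contain attention blocks

def sd15_head_dims(num_heads: int = 8) -> list:
    # sdv1.5: a fixed number of heads, so head_dim grows with the channel width
    return [w // num_heads for w in ATTN_WIDTHS]

def sd20_head_dims(num_head_channels: int = 64) -> list:
    # sdv2.0: a fixed number of channels per head at every width
    return [num_head_channels for _ in ATTN_WIDTHS]

def uses_flash_attention(head_dim: int, fa_max_head_dim: int) -> bool:
    # Assumed rule: FA only when head_dim <= fa_max_head_dim, else vanilla attention
    return head_dim <= fa_max_head_dim

for cap in (128, 256):
    print(cap, [(d, uses_flash_attention(d, cap)) for d in sd15_head_dims()])
```

Under these assumptions the sdv1.5 head dims are 40, 80 and 160. Lowering the cap from 256 to 128 therefore moves only the 160-dim layers off FA, while the 64-dim layers of sdv2.0 are unaffected either way.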

hadipash commented 5 months ago

Close #291 then, to avoid it being merged, since it has already passed review?

wtomin commented 5 months ago


Good suggestion!

SamitHuang commented 5 months ago

Is the generated image on sd1.5 + 910B good?

wtomin commented 5 months ago


Yes, the quality is good.

wtomin commented 5 months ago

Good analysis of FA! But I find it amusing that FA can cause OOM, even though it is supposed to be more memory-efficient than vanilla attention.

The OOM caused by FA is different from the OOM caused by a large batch size. According to the flash-attention implementation in the MindSpore library, on 910A the head dimension is restricted to less than 304; otherwise it causes a UB OOM. Although I am not 100% sure, I guess UB refers to some ultra-high-speed NPU memory used for exchanging data.
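A rough way to see why this OOM depends on head_dim rather than batch size: a tiled FlashAttention kernel processes one block of queries against one block of keys at a time. Its on-chip working set therefore scales with the block sizes and head_dim, and batch size does not enter it. The sketch below is only an estimate under stated assumptions. The block sizes, the fp16 element size and the set of buffers resident at once (Q, K, V and output tiles plus the score tile) are illustrative guesses, not the MindSpore kernel's actual layout. `budget_bytes` is a hypothetical capacity, not the real UB size.

```python
# Hypothetical estimate of the on-chip working set of one FlashAttention tile.
# Assumes one Q/O block (block_q x head_dim) and one K/V block (block_k x head_dim)
# are resident together with the block_q x block_k score tile; the small per-row
# softmax statistics are ignored. Batch size never appears: tiles run one at a time.

def tile_bytes(block_q: int, block_k: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    q_and_o = 2 * block_q * head_dim  # Q tile + output accumulator
    k_and_v = 2 * block_k * head_dim  # K tile + V tile
    scores = block_q * block_k        # S = Q K^T for this tile pair
    return bytes_per_elem * (q_and_o + k_and_v + scores)

def max_head_dim(budget_bytes: int, block_q: int, block_k: int, bytes_per_elem: int = 2) -> int:
    """Largest head_dim whose tile fits in budget_bytes (a hypothetical budget)."""
    elems = budget_bytes // bytes_per_elem - block_q * block_k
    return max(elems // (2 * (block_q + block_k)), 0)

for d in (64, 128, 160, 256, 304):
    print(d, tile_bytes(64, 64, d))  # grows linearly with head_dim
```

Under this model the footprint grows linearly with head_dim, so any fixed on-chip buffer implies a head_dim ceiling. Raising the batch size only adds more tiles to process, not larger ones.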