mindspore-lab / mindone

one for all, Optimal generator with No Exception
https://mindspore-lab.github.io/mindone/
Apache License 2.0

Fix sdv1.5 + FlashAttention bug on 910B #310

Closed · wtomin closed this 5 months ago

wtomin commented 5 months ago

As I investigated on 910B, the reason sdv1.5 + FA generates noisy images is that ms.nn.layer.flash_attention (MindSpore version 2.2.10.20231124) behaves incorrectly when the head dimension equals 160. Note that each attention layer's head dimension in sdv1.5 is one of [40, 80, 160, 320].
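For reference, a minimal sketch of the shape such a guard could take: dispatch to FlashAttention only for head dimensions it handles correctly, and fall back to vanilla scaled dot-product attention otherwise. The `flash_attention` callable and its signature here are assumptions for illustration, not the actual fixes evaluated below:

```python
from mindspore import ops

# Head dims observed to be broken in ms.nn.layer.flash_attention
# on 910B with MindSpore 2.2.10.20231124 (see above).
FA_BROKEN_HEAD_DIMS = {160}

def attention(q, k, v, flash_attention=None):
    """q, k, v: (batch, num_heads, seq_len, head_dim) Tensors.

    `flash_attention` is a hypothetical callable wrapping the FA kernel;
    when it is unset or the head dim is known-broken, take the vanilla path.
    """
    head_dim = q.shape[-1]
    if flash_attention is not None and head_dim not in FA_BROKEN_HEAD_DIMS:
        return flash_attention(q, k, v)
    # Vanilla scaled dot-product attention as the safe fallback.
    scores = ops.matmul(q, k.transpose(0, 1, 3, 2)) * head_dim ** -0.5
    weights = ops.softmax(scores, axis=-1)
    return ops.matmul(weights, v)
```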

There are three candidate solutions (v1, v2, and v3).

When running inference with text_to_image.py on 910B, I gradually increased the batch size until an OOM error was triggered. Here are the results:

| solution | OOM batch size | Not-OOM batch size |
| --- | --- | --- |
| v1 | 300 | 284 |
| v2 | 366 | 360 |
| v3 | 384 | 372 |
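The search above was done by hand; a minimal sketch of how it could be automated is below. `run_inference` is a hypothetical stand-in for launching a text_to_image.py run at a given batch size:

```python
def max_safe_batch_size(run_inference, start=4, step=4, limit=512):
    """Increase batch size until OOM; return the last size that succeeded.

    `run_inference` is a hypothetical callable assumed to raise RuntimeError
    on a device OOM (how MindSpore surfaces allocation failures may vary).
    """
    last_ok = None
    bs = start
    while bs <= limit:
        try:
            run_inference(batch_size=bs)
            last_ok = bs
        except RuntimeError:
            break
        bs += step
    return last_ok
```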

From the perspective of memory efficiency, I suggest using v3.

However, this is only a temporary fix. If MindSpore fixes this bug in a future release, we should remove this workaround.
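One way to keep the workaround easy to remove later is to gate it on the MindSpore version. A sketch, assuming the bug affects builds up to the 2.2.10 line noted above (the exact fixed release is not yet known):

```python
import mindspore as ms

def fa_head_dim_workaround_needed():
    # Observed broken on 2.2.10.20231124; treat builds up to 2.2.10 as
    # affected. The upper bound is an assumption until MindSpore confirms
    # which release carries the fix.
    parts = tuple(int(x) for x in ms.__version__.split(".")[:3])
    return parts <= (2, 2, 10)
```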

@Songyuanwei, please help check whether solution v3 has any effect on the sdv1.5 finetuning experiment. Thank you!

zhtmike commented 5 months ago

I suggest submitting an issue on the MindSpore Gitee, since this bug heavily impacts the use of flash attention.

wtomin commented 5 months ago

> I suggest submitting an issue on the MindSpore Gitee, since this bug heavily impacts the use of flash attention.

I agree with you. I will approach someone on the MindSpore team and report this bug directly, but this temporary fix is still needed because we need FA to meet some deadlines.

SamitHuang commented 5 months ago

Is the training result also fine now?

wtomin commented 5 months ago

> Is the training result also fine now?

The finetuning experiments are still in progress.