Kademo15 opened 2 days ago
Yeah, in my quickdif app I have a setup that uses try/except instead of a head-dim check, so it supports any flash attention library:
```python
import torch
from torch import Tensor
from typing import Callable


def _patch_sdpa(
    patch_func: Callable[[Tensor, Tensor, Tensor, Tensor | None, float, bool, float | None], Tensor],
):
    """(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)"""
    torch_sdpa = torch.nn.functional.scaled_dot_product_attention

    def sdpa_hijack_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        try:
            # Any failure inside the patch (unsupported head dim, dtype, mask, ...)
            # silently falls back to stock SDPA.
            return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale)
        except Exception:
            return torch_sdpa(
                query=query,
                key=key,
                value=value,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
                scale=scale,
            )

    torch.nn.functional.scaled_dot_product_attention = sdpa_hijack_flash


try:
    from flash_attn import flash_attn_func

    def sdpa_hijack_flash(q, k, v, m, p, c, s):
        assert m is None  # flash_attn_func takes no attention mask
        # SDPA layout is (batch, heads, seq, dim); flash_attn_func wants (batch, seq, heads, dim)
        result = flash_attn_func(
            q=q.transpose(1, 2),
            k=k.transpose(1, 2),
            v=v.transpose(1, 2),
            dropout_p=p,
            softmax_scale=s if s else q.shape[-1] ** (-0.5),
            causal=c,
        )
        assert isinstance(result, Tensor)
        return result.transpose(1, 2)

    _patch_sdpa(sdpa_hijack_flash)
    print("# # #\nPatched SDPA with Flash Attention\n# # #")
except ImportError as e:
    print(f"# # #\nCould not load Flash Attention for hijack:\n{e}\n# # #")
```
I don't see why it wouldn't also work for ComfyUI, but I haven't used Comfy in a while.
So if I just slapped that code into your node, that could work?
Probably. Replace the existing block, but leave the node mappings at the bottom; Comfy needs those.
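For reference, the file would end up shaped roughly like this — a sketch only, with placeholder names rather than this repo's actual node class:

```python
# custom_nodes/attention_patch.py -- placeholder names, not the repo's actual node
import torch

# ... the _patch_sdpa / flash_attn hijack block from above goes here ...

class AttentionPatchNode:
    # Minimal ComfyUI node boilerplate; the real node's inputs will differ.
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {}}

    RETURN_TYPES = ()
    FUNCTION = "apply"
    CATEGORY = "patches"

    def apply(self):
        return ()

# ComfyUI discovers nodes through these module-level dicts -- keep them intact.
NODE_CLASS_MAPPINGS = {"AttentionPatchNode": AttentionPatchNode}
NODE_DISPLAY_NAME_MAPPINGS = {"AttentionPatchNode": "Attention Patch"}
```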
Is it possible to support this version of flash attention? I find that by just editing your code and replacing the 128 with 512, I get an error with Stable Diffusion 1.5. Would it be possible to write the code to dynamically switch the head dim based on the model that's running? Furthermore, I don't have enough knowledge to know whether just replacing 128 would make it support the 512 version. If you could look into that, it would be great.
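For what it's worth, a dynamic dispatch would probably look something like the sketch below. It assumes the ceiling has to be configured by hand for whichever flash-attn build is installed (128 vs. the 512 variant you mention), since I'm not aware of an API for querying it:

```python
import torch
from torch import Tensor
from flash_attn import flash_attn_func

# Assumed, hand-configured ceiling of the installed flash-attn build;
# swap in 512 if that's what your build supports.
MAX_HEAD_DIM = 128

def sdpa_hijack_flash(q, k, v, m, p, c, s):
    # Head dim is the last axis of SDPA-layout (batch, heads, seq, dim) tensors.
    if m is not None or q.shape[-1] > MAX_HEAD_DIM:
        # Raising here makes the try/except wrapper above fall back to stock
        # SDPA, so models whose head dims exceed the ceiling still run.
        raise ValueError("unsupported by this flash-attn build")
    return flash_attn_func(
        q=q.transpose(1, 2),
        k=k.transpose(1, 2),
        v=v.transpose(1, 2),
        dropout_p=p,
        softmax_scale=s if s else q.shape[-1] ** (-0.5),
        causal=c,
    ).transpose(1, 2)
```

Since the check keys off the tensor shapes at call time, nothing needs to detect which model is loaded — and with the try/except wrapper above, even the explicit check is optional, because an unsupported head dim already raises inside flash_attn_func and triggers the fallback.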