Beinsezii / comfyui-amd-go-fast

Simple monkeypatch to boost AMD Navi 3 GPUs
MIT License

Support bigger head dims #3

Open Kademo15 opened 2 days ago

Kademo15 commented 2 days ago

Is it possible to support this version of Flash Attention? I found that just editing your code to replace the 128 with 512 gives me an error with Stable Diffusion 1.5. Could the code switch the head dim dynamically based on the model that's running? I also don't know enough to tell whether replacing 128 alone would make it support the 512 version. It would be great if you could look into that.
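A dim check does not need to know the model in advance: the head dim is simply the last dimension of the query tensor that reaches SDPA, so it can be read on every call. A minimal sketch, assuming a hypothetical `fast_attention` kernel (here just a stand-in that calls PyTorch SDPA) and a hypothetical `MAX_HEAD_DIM` limit; the real limit depends on the installed Flash Attention build.

```python
import torch
import torch.nn.functional as F

# Hypothetical limit: set this to whatever the installed fast kernel supports.
MAX_HEAD_DIM = 128


def fast_attention(q, k, v):
    # Hypothetical stand-in for a Flash Attention kernel; it just calls PyTorch SDPA here.
    return F.scaled_dot_product_attention(q, k, v)


def dispatch_attention(q, k, v):
    """Pick a backend per call from the query's head dim; returns (output, path taken)."""
    head_dim = q.shape[-1]  # SDPA layout: (batch, heads, seq_len, head_dim)
    if head_dim <= MAX_HEAD_DIM:
        return fast_attention(q, k, v), "fast"
    return F.scaled_dot_product_attention(q, k, v), "fallback"


q = torch.randn(1, 8, 16, 64)
out, path = dispatch_attention(q, q, q)  # head dim 64 takes the fast path
```

Because the check runs per call, different attention layers of the same model can take different paths.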

Beinsezii commented 1 day ago

Yeah, in my quickdif app I have a setup that uses try/except instead of a dim check, so it supports any Flash Attention library:

```python
from typing import Callable, Optional

import torch
from torch import Tensor


def _patch_sdpa(
    patch_func: Callable[[Tensor, Tensor, Tensor, Optional[Tensor], float, bool, Optional[float]], Tensor],
):
    """Make SDPA try patch_func first and fall back to PyTorch's SDPA on any error.

    patch_func is called as (query, key, value, attn_mask, dropout_p, is_causal, scale).
    """

    torch_sdpa = torch.nn.functional.scaled_dot_product_attention

    def sdpa_hijack(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        try:
            # Fast path: e.g. Flash Attention. Unsupported head dims, dtypes or masks raise here.
            return patch_func(query, key, value, attn_mask, dropout_p, is_causal, scale)
        except Exception:
            # Fallback: PyTorch's own SDPA handles anything the patched kernel rejects.
            return torch_sdpa(
                query=query,
                key=key,
                value=value,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
                scale=scale,
            )

    torch.nn.functional.scaled_dot_product_attention = sdpa_hijack


try:
    from flash_attn import flash_attn_func

    def flash_attn_sdpa(q, k, v, m, p, c, s):
        # flash_attn_func takes no attention mask, so raise to trigger the fallback.
        # (A bare assert would be stripped under `python -O` and the mask silently ignored.)
        if m is not None:
            raise ValueError("flash_attn_func does not support attn_mask")
        # SDPA uses (batch, heads, seq, head_dim); flash_attn_func expects (batch, seq, heads, head_dim).
        result = flash_attn_func(
            q=q.transpose(1, 2),
            k=k.transpose(1, 2),
            v=v.transpose(1, 2),
            dropout_p=p,
            softmax_scale=s if s is not None else q.shape[-1] ** (-0.5),
            causal=c,
        )
        assert isinstance(result, Tensor)  # type narrowing only
        return result.transpose(1, 2)

    _patch_sdpa(flash_attn_sdpa)
    print("# # #\nPatched SDPA with Flash Attention\n# # #")
except ImportError as e:
    print(f"# # #\nCould not load Flash Attention for hijack:\n{e}\n# # #")
```

I don't see why it wouldn't also work for ComfyUI, but I haven't used Comfy in a while.

Kademo15 commented 1 day ago

So if I just slap that code into your node, it could work?

Beinsezii commented 1 day ago

Probably. Replace the existing block, but leave the node mappings at the bottom; Comfy needs those.