Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Inquiry: Integration of CoPE with FlashAttention #1062

Closed · relic-yuexi closed this issue 2 months ago

relic-yuexi commented 2 months ago

Dear Sir,

I hope this message finds you well. I am currently working on integrating the Contextual Position Encoding (CoPE) module with the FlashAttention mechanism, and I am reaching out to inquire if there is an existing or recommended method to facilitate this integration seamlessly.

Below is a brief overview of the CoPE module I am attempting to integrate:

import math

import torch
import torch.nn as nn

class CoPE(nn.Module):
    def __init__(self, npos_max, head_dim):
        super().__init__()
        self.npos_max = npos_max
        self.pos_emb = nn.Parameter(torch.zeros(1, head_dim, npos_max))

    def forward(self, query, attn_logits):
        # query: batch x seq_len x head_dim, attn_logits: batch x seq_len x seq_len
        # Compute fractional positions from the gated attention logits
        gates = torch.sigmoid(attn_logits)
        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
        pos = pos.clamp(max=self.npos_max - 1)
        # Interpolate between the two nearest integer position embeddings
        pos_ceil = pos.ceil().long()
        pos_floor = pos.floor().long()
        logits_int = torch.matmul(query, self.pos_emb)  # batch x seq_len x npos_max
        logits_ceil = logits_int.gather(-1, pos_ceil)
        logits_floor = logits_int.gather(-1, pos_floor)
        w = pos - pos_floor
        return logits_ceil * w + logits_floor * (1 - w)

class SelfAttn(nn.Module):
    def __init__(self, npos_max, head_dim):
        super().__init__()
        self.cope = CoPE(npos_max, head_dim)
        self.head_dim = head_dim

    def forward(self, query, key, val, mask):
        # query, key, val have shape batch x seq_len x head_dim
        attn_logits = torch.bmm(query, key.transpose(-1, -2))
        attn_logits = attn_logits / math.sqrt(self.head_dim)
        attn_logits += mask.log()  # 0/1 mask: the log turns disallowed positions into -inf
        attn_logits += self.cope(query, attn_logits)
        attn = torch.softmax(attn_logits, dim=-1)
        out = torch.bmm(attn, val)
        return out
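
For reference, here is a minimal usage sketch of the module above (the batch size, sequence length, npos_max, and the causal-mask construction are illustrative assumptions, not part of the original code):

# Minimal usage sketch; dimensions and the causal mask are illustrative assumptions
batch, seq_len, head_dim, npos_max = 2, 8, 64, 16

attn = SelfAttn(npos_max, head_dim)
q = torch.randn(batch, seq_len, head_dim)
k = torch.randn(batch, seq_len, head_dim)
v = torch.randn(batch, seq_len, head_dim)

# Lower-triangular 0/1 causal mask; SelfAttn takes its log, so zeros become -inf
mask = torch.tril(torch.ones(seq_len, seq_len))

out = attn(q, k, v, mask)
print(out.shape)  # torch.Size([2, 8, 64])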

I have found that there is a parameter called return_attn_probs=True, but the values it returns do not match my manual computation.


from flash_attn.flash_attn_interface import flash_attn_func
import torch

# Intended tensor shape: (batch_size, seqlen, nheads, headdim)
batch_size = 1
seqlen = 1
nheads = 1
nheads_k = 1
headdim = 2

# Define simple tensors for q, k, and v
q = torch.tensor([[[[1.0, 2.0]]]], dtype=torch.float16).to("cuda")
k = torch.tensor([[[[3.0, 4.0]]]], dtype=torch.float16).to("cuda")
v = torch.tensor([[[[5.0, 6.0]]]], dtype=torch.float16).to("cuda")

out, attn_probs, _ = flash_attn_func(q, k, v, return_attn_probs=True)

print("Output shape:", out.shape)
print("Attention probabilities shape:", attn_probs.shape)

# Print the output and attention probabilities for manual verification
print("Output:", out)
print("Attention probabilities:", attn_probs)

This prints: Attention probabilities: tensor([[[7.7782]]], device='cuda:0')

import torch
import torch.nn.functional as F

# Given tensors
q = torch.tensor([[[[1.0, 2.0]]]], dtype=torch.float16).to("cuda")
k = torch.tensor([[[[3.0, 4.0]]]], dtype=torch.float16).to("cuda")
v = torch.tensor([[[[5.0, 6.0]]]], dtype=torch.float16).to("cuda")

# Step 1: Compute the dot product of q and k
qk = torch.matmul(q, k.transpose(-2, -1))
print("Step 1 - q * k^T:")
print(qk)

# Step 2: Scale the dot product
d_k = k.size(-1)
scaled_qk = qk / torch.sqrt(torch.tensor(d_k, dtype=torch.float16).to("cuda"))
print("\nStep 2 - Scaled q * k^T:")
print(scaled_qk)

# Step 3: Apply softmax
attention_weights = F.softmax(scaled_qk, dim=-1)
print("\nStep 3 - Softmax(Scaled q * k^T):")
print(attention_weights)

# Step 4: Multiply with v
output = torch.matmul(attention_weights, v)
print("\nStep 4 - Final Output:")
print(output)

This gives tensor([[[[7.7773]]]], device='cuda:0', dtype=torch.float16).
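
As a cross-check, I would compare only the attention output (rather than the returned probabilities) against PyTorch's F.scaled_dot_product_attention. The layout handling below (flash_attn_func takes batch x seqlen x nheads x headdim, while scaled_dot_product_attention expects batch x nheads x seqlen x headdim) reflects my understanding of the two APIs, so treat it as a sketch:

import torch
import torch.nn.functional as F
from flash_attn.flash_attn_interface import flash_attn_func

# Same toy tensors as above, shape (batch, seqlen, nheads, headdim) = (1, 1, 1, 2)
q = torch.tensor([[[[1.0, 2.0]]]], dtype=torch.float16, device="cuda")
k = torch.tensor([[[[3.0, 4.0]]]], dtype=torch.float16, device="cuda")
v = torch.tensor([[[[5.0, 6.0]]]], dtype=torch.float16, device="cuda")

out_fa = flash_attn_func(q, k, v)

# scaled_dot_product_attention expects (batch, nheads, seqlen, headdim)
out_ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

print(out_fa)   # with a single key, the softmax weight is 1.0, so the output equals v
print(out_ref)
print(torch.allclose(out_fa, out_ref, atol=1e-3))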

I am particularly interested in understanding if there are any existing utilities or guidelines within the FlashAttention framework that could simplify the process of incorporating CoPE. Additionally, any insights or suggestions on potential challenges or optimizations would be greatly appreciated.

Thank you for your time and assistance. I look forward to your response.

tridao commented 2 months ago

As the docstring says:

return_attn_probs: bool. Whether to return the attention probabilities. This option is for
    testing only. The returned probabilities are not guaranteed to be correct
    (they might not have the right scaling).

tridao commented 2 months ago

Tbh I don't see an easy way to get CoPE to go fast.

relic-yuexi commented 2 months ago

I originally thought that just speeding up the qk operation would be enough. 😂