Open vnkc1 opened 1 day ago
I tried to imitate your educational coding style hehe
Here's a pure PyTorch implementation of Flash Attention, hope you like it @karpathy!
```python
import torch

def flash_attention(Q, K, V, is_causal=True, BLOCK_SIZE: int = 64):
    NEG_INFINITY = -1e10
    EPS = 1e-10

    B, nh, T, H = Q.shape
    scale = H ** -0.5
    assert Q.shape == K.shape and Q.shape == V.shape, "Some of Q, K, V are misshapen!"
    # TODO: Allow small sequences
    assert T >= BLOCK_SIZE, "For small sequences, use standard attention!"

    # initialize buffers: running output, running row-wise maximum, running softmax denominator
    outputs = torch.zeros_like(Q)
    maximums = torch.full((B, nh, T, 1), fill_value=NEG_INFINITY, device=Q.device, dtype=Q.dtype)
    denominators = torch.full((B, nh, T, 1), fill_value=EPS, device=Q.device, dtype=Q.dtype)

    # chop up matrices into blocks along the sequence dimension
    Q_blocks, K_blocks, V_blocks = map(
        lambda x: torch.split(x, BLOCK_SIZE, dim=2), (Q, K, V)
    )
    O_blocks, M_blocks, D_blocks = map(
        lambda x: list(torch.split(x, BLOCK_SIZE, dim=2)),
        (outputs, maximums, denominators),
    )

    # helper variables for the causal mask
    positions = torch.arange(0, T, device=Q.device)
    K_index_blocks = torch.split(positions[None, :], BLOCK_SIZE, dim=1)
    Q_index_blocks = torch.split(positions[:, None], BLOCK_SIZE, dim=0)

    for k_index in range(len(K_blocks)):
        k_block = K_blocks[k_index]
        v_block = V_blocks[k_index]
        for q_index in range(len(Q_blocks)):
            # calculate scaled attention scores for this (query block, key block) tile
            q_block = Q_blocks[q_index]
            attn = q_block @ k_block.permute(0, 1, 3, 2) * scale
            # apply the causal mask: a query may only attend to keys at or before its own position
            if is_causal:
                causal_mask = K_index_blocks[k_index] <= Q_index_blocks[q_index]
                attn = torch.where(causal_mask, attn, NEG_INFINITY)
            # calculate the new running maximum attention score per query vector
            old_maximum = M_blocks[q_index]
            local_maximum, _ = torch.max(attn, dim=-1, keepdim=True)
            new_maximum = torch.maximum(old_maximum, local_maximum)
            # now that the maximum is known, we can safely exponentiate the attention scores
            attn = torch.exp(attn - new_maximum)
            # adjust the old softmax denominator to the new maximum, then update it
            denominator_scaler = torch.exp(old_maximum - new_maximum)
            denominator_update = torch.sum(attn, dim=-1, keepdim=True)
            old_denominator = D_blocks[q_index] * denominator_scaler
            new_denominator = old_denominator + denominator_update
            # adjust and update the (already normalized) attention output
            output_scaler = old_denominator / new_denominator
            output_update = attn @ v_block / new_denominator
            old_output = O_blocks[q_index] * output_scaler
            new_output = old_output + output_update
            # store the new maximums, denominators and attention output
            M_blocks[q_index] = new_maximum
            D_blocks[q_index] = new_denominator
            O_blocks[q_index] = new_output

    # patch the attention output blocks back together into a single (B, nh, T, H) tensor
    return torch.cat(O_blocks, dim=2)
```
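For what it's worth, here's the kind of quick sanity check I'd run against it (a minimal sketch; it assumes PyTorch 2.0+ for `F.scaled_dot_product_attention`, and the shapes/tolerances are just made up for illustration):

```python
import torch
import torch.nn.functional as F

# random test inputs of shape (B, nh, T, H)
B, nh, T, H = 2, 4, 256, 32
Q = torch.randn(B, nh, T, H)
K = torch.randn(B, nh, T, H)
V = torch.randn(B, nh, T, H)

# blockwise flash attention vs. PyTorch's reference causal attention
out_flash = flash_attention(Q, K, V, is_causal=True, BLOCK_SIZE=64)
out_ref = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# should agree up to small numerical error from the online softmax rescaling
print(torch.allclose(out_flash, out_ref, atol=1e-5))
```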
Inspired by Shreyansh's implementation.