GindaChen opened this issue 3 months ago
Hey, I took some liberties with the interpretation of your problem; I still need to read through the links you shared.
My gut reaction is something like this:
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H = 2, 3
max_tokens = 128

# One (start, stop) span per (batch, head) and one additive scale per (batch, head).
highlighted_tokens = torch.zeros(B, H, 2)
scales = torch.rand(B, H, 1)
highlighted_tokens[0, 0] = torch.tensor([2, 5])
highlighted_tokens[0, 1] = torch.tensor([3, 7])
highlighted_tokens[0, 2] = torch.tensor([1, 4])

def attention_steering(score, b, h, q_idx, kv_idx):
    # Boost the score only when both the query and key/value token fall inside the span.
    start_stop_pairs = highlighted_tokens[b, h]
    valid_q_token = (q_idx >= start_stop_pairs[0]) & (q_idx < start_stop_pairs[1])
    valid_kv_token = (kv_idx >= start_stop_pairs[0]) & (kv_idx < start_stop_pairs[1])
    scale_ = scales[b, h, torch.tensor(0)]
    return torch.where(valid_q_token & valid_kv_token, score + scale_, score)

query = torch.rand(B, H, max_tokens, 8)
key = torch.rand(B, H, max_tokens, 8)
value = torch.rand(B, H, max_tokens, 8)

# This won't work as-is, but it conceptually shows the possibility:
output = flex_attention(query, key, value, score_mod=attention_steering)
Assumptions: the tokens to be steered lie within a contiguous span, and there is only one scale to add per sequence.
This implementation grabs the beginning and end of that span, checks whether the current q_idx and kv_idx both lie within it, and if so adds the single scalar; otherwise it returns the score as-is.
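For intuition, here is a dense reference version of the same idea (my own sketch, not part of the snippet above): it materializes the full [B, H, Q, KV] additive bias and passes it to scaled_dot_product_attention as attn_mask, which is what the score_mod above computes one score element at a time without ever building that tensor.

```python
import torch
import torch.nn.functional as F

B, H, S, D = 2, 3, 128, 8
highlighted_tokens = torch.zeros(B, H, 2, dtype=torch.long)  # (start, stop) per (b, h)
scales = torch.rand(B, H, 1)
highlighted_tokens[0, 0] = torch.tensor([2, 5])

idx = torch.arange(S)
start, stop = highlighted_tokens[..., 0:1], highlighted_tokens[..., 1:2]  # [B, H, 1]
in_span = (idx >= start) & (idx < stop)                                   # [B, H, S]
pair_in_span = in_span[..., :, None] & in_span[..., None, :]              # [B, H, S, S]
bias = pair_in_span.float() * scales[..., None]                           # additive score bias

q = torch.rand(B, H, S, D)
k = torch.rand(B, H, S, D)
v = torch.rand(B, H, S, D)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```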
@drisspg Thanks for helping out! I modified the script a bit and found a way that works (with multiple spans, but one scale factor per sequence):
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H = 8, 8
max_tokens = 1024

# Boolean mask of highlighted tokens per (batch, head), plus one scale per (batch, head).
highlighted_tokens = torch.zeros(B, H, max_tokens, dtype=torch.bool)
scales = torch.rand(B, H, 1)

highlighted_token_idx = {
    # (batch_id, head_id) -> (token_idxs, scalar_factor)
    (0, 2): ([1, 3, 5], 1.2),
    (1, 3): ([2, 4], 1.5),
}

# Set up highlighted tokens and scaling factors
for (b, h), (token_idxs, scalar_factor) in highlighted_token_idx.items():
    highlighted_tokens[b, h, token_idxs] = True
    scales[b, h, 0] = scalar_factor

def attention_steering(score, b, h, q_idx, kv_idx):
    # Add the per-sequence scale when both the q and kv tokens are highlighted.
    token_mask = highlighted_tokens[b, h]
    valid_q_token = token_mask[q_idx]
    valid_kv_token = token_mask[kv_idx]
    scale_ = scales[b, h, torch.tensor(0)]
    return torch.where(valid_q_token & valid_kv_token, score + scale_, score)

def attention_steering_scale_down(score, b, h, q_idx, kv_idx):
    # Subtract the per-sequence scale when neither the q nor the kv token is highlighted.
    token_mask = highlighted_tokens[b, h]
    valid_q_token = ~token_mask[q_idx]
    valid_kv_token = ~token_mask[kv_idx]
    scale_ = scales[b, h, torch.tensor(0)]
    return torch.where(valid_q_token & valid_kv_token, score - scale_, score)

query = torch.rand(B, H, max_tokens, 8)
key = torch.rand(B, H, max_tokens, 8)
value = torch.rand(B, H, max_tokens, 8)

output = flex_attention(query, key, value, score_mod=attention_steering)
output_scale_down = flex_attention(query, key, value, score_mod=attention_steering_scale_down)
I'm also considering adding a page table (because the mapping between the batch id and the 0th-dim index of highlighted_tokens can change) to make this complete.
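Roughly, I picture it like this (a rough sketch only; the `batch_to_row` indirection tensor is a placeholder name, and I'm assuming tensor indexing inside the score_mod keeps working the same way as above):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, max_tokens = 8, 8, 1024
highlighted_tokens = torch.zeros(B, H, max_tokens, dtype=torch.bool)
scales = torch.rand(B, H, 1)
highlighted_tokens[0, 2, [1, 3, 5]] = True

# "Page table": maps the runtime batch index to a row of highlighted_tokens/scales,
# so rows can be reassigned without rewriting the masks themselves.
batch_to_row = torch.arange(B)  # identity mapping; permute entries to remap batches

def attention_steering_paged(score, b, h, q_idx, kv_idx):
    row = batch_to_row[b]
    token_mask = highlighted_tokens[row, h]
    scale_ = scales[row, h, torch.tensor(0)]
    return torch.where(token_mask[q_idx] & token_mask[kv_idx], score + scale_, score)

query = torch.rand(B, H, max_tokens, 8)
key = torch.rand(B, H, max_tokens, 8)
value = torch.rand(B, H, max_tokens, 8)
output = flex_attention(query, key, value, score_mod=attention_steering_paged)
```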
If you think this is interesting, I can open a PR to attention-gym to showcase this example?
Yeah sounds super interesting, feel free to open up a PR!
I really love FlexAttention and attention-gym! Thanks for sharing this great resource.
I'm trying to implement something similar to PASTA. The core logic is simple:
Now I want to transform this algorithm into matrix form. I tried to add a static matrix to `score`, but that does not work. I also found it hard to understand the dimensions of `score`. Maybe I'm missing something here. Is it possible to implement this logic using FlexAttention (assuming these scalars are still fixed)?
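Concretely, what I'm hoping for is something along these lines (a minimal sketch with placeholder names like `static_bias` and `add_static_bias`; I'm not sure this is the right way to think about `score`): my understanding is that score_mod sees one scalar score per (b, h, q_idx, kv_idx) combination, so a precomputed bias matrix would just be looked up elementwise and added.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 128, 16
static_bias = torch.zeros(B, H, S, S)
static_bias[0, 1, :, 2:5] = 1.5  # e.g. steer every query toward tokens 2..4 on one head

def add_static_bias(score, b, h, q_idx, kv_idx):
    # score is a single scalar for one (b, h, q_idx, kv_idx) combination,
    # so the precomputed bias can simply be indexed and added.
    return score + static_bias[b, h, q_idx, kv_idx]

q = torch.rand(B, H, S, D)
k = torch.rand(B, H, S, D)
v = torch.rand(B, H, S, D)
out = flex_attention(q, k, v, score_mod=add_static_bias)
```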