pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

[question] Possible to implement Attention Steering? #10

Open GindaChen opened 2 months ago

GindaChen commented 2 months ago

I really love FlexAttention and attention-gym! Thanks for sharing this great resource.

I'm trying to implement something similar to PASTA. The core logic is simple:

highlighted_token_idx = {
    # (batch_id, head_id) -> (token_idxs, scalar_factor)
    (0, 2): ([1, 3, 5], 1.2),
    (1, 3): ([2, 4], 1.5),
}

def attention_steering(score, b, h, q_idx, kv_idx):
    for (bi, head_idx), (tok_idx, scalar) in highlighted_token_idx.items():
        score[bi, head_idx, :, tok_idx] += scalar
    return score

...
# This won't work, but just conceptually showing the possibility:
output = flex_attention(query, key, value, score_mod=attention_steering)

Now I want to transform this algorithm into matrix form. I tried adding a static bias matrix to score, but that does not work, and I also found it hard to understand the dimensions of score.

Maybe I'm missing something here. Is it possible to implement this logic using FlexAttention (assuming these scalars are still fixed)?
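
For context, a score_mod in FlexAttention is written elementwise: score is a single scalar for the position (b, h, q_idx, kv_idx), not the full attention matrix, which is why adding a whole matrix to it does not work. A minimal sketch of the usual pattern, applying a precomputed bias tensor by indexing it (names here are placeholders, not from the thread):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 128, 16
bias = torch.zeros(B, H, S, S)  # hypothetical precomputed additive bias

def add_bias(score, b, h, q_idx, kv_idx):
    # score is one scalar; look up the captured bias at the same position
    return score + bias[b, h, q_idx, kv_idx]

q = k = v = torch.rand(B, H, S, D)
out = flex_attention(q, k, v, score_mod=add_bias)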

drisspg commented 2 months ago

Hey, I took some liberties with the interpretation of your problem; I still need to read through the links you shared.

My gut reaction is something like this:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H = 2, 3
max_tokens = 128

# Per (batch, head): the [start, stop) span to steer, plus one scale each.
highlighted_tokens = torch.zeros(B, H, 2)
scales = torch.rand(B, H, 1)

highlighted_tokens[0, 0] = torch.tensor([2, 5])
highlighted_tokens[0, 1] = torch.tensor([3, 7])
highlighted_tokens[0, 2] = torch.tensor([1, 4])

def attention_steering(score, b, h, q_idx, kv_idx):
    start_stop_pairs = highlighted_tokens[b, h]
    valid_q_token = (q_idx >= start_stop_pairs[0]) & (q_idx < start_stop_pairs[1])
    valid_kv_token = (kv_idx >= start_stop_pairs[0]) & (kv_idx < start_stop_pairs[1])
    scale_ = scales[b, h, torch.tensor(0)]
    return torch.where(valid_q_token & valid_kv_token, score + scale_, score)

query = torch.rand(B, H, max_tokens, 8)
key = torch.rand(B, H, max_tokens, 8)
value = torch.rand(B, H, max_tokens, 8)
output = flex_attention(query, key, value, score_mod=attention_steering)

Assumptions: the tokens to be steered lie within a single contiguous span, and there is only one scale to add per sequence.

This implementation essentially grabs the beginning and end of that span, checks whether the current q_idx and kv_idx both lie within it, and if so adds the single scalar; otherwise it returns the score as is.
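
As a sanity check (continuing the snippet above; not from the thread, just one way to convince yourself the score_mod does what you expect), the same steering can be materialized as a dense bias and compared against scaled_dot_product_attention with a float attn_mask:

import torch.nn.functional as F

# Build the equivalent dense bias from the captured spans and scales.
bias = torch.zeros(B, H, max_tokens, max_tokens)
for b in range(B):
    for h in range(H):
        start, stop = highlighted_tokens[b, h].long().tolist()
        bias[b, h, start:stop, start:stop] = scales[b, h, 0]

# A float attn_mask is added to the scaled scores, matching the score_mod above.
ref = F.scaled_dot_product_attention(query, key, value, attn_mask=bias)
torch.testing.assert_close(output, ref, rtol=1e-4, atol=1e-4)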

GindaChen commented 2 months ago

@drisspg Thanks for helping out! I modified the script a bit and found a way that works (with multiple spans, but one scale factor per sequence):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H = 8, 8
max_tokens = 1024

# Boolean mask of highlighted tokens per (batch, head)
highlighted_tokens = torch.zeros(B, H, max_tokens, dtype=torch.bool)
scales = torch.rand(B, H, 1)

highlighted_token_idx = {
    # (batch_id, head_id) -> (token_idxs, scalar_factor)
    (0, 2): ([1, 3, 5], 1.2),
    (1, 3): ([2, 4], 1.5),
}

# Set up the highlighted-token masks and per-sequence scaling factors
for (b, h), (token_idxs, scalar_factor) in highlighted_token_idx.items():
    highlighted_tokens[b, h, token_idxs] = True
    scales[b, h, 0] = scalar_factor

def attention_steering(score, b, h, q_idx, kv_idx):
    # Boost scores where both the query token and the key/value token are highlighted.
    token_mask = highlighted_tokens[b, h]
    valid_q_token = token_mask[q_idx]
    valid_kv_token = token_mask[kv_idx]
    scale_ = scales[b, h, torch.tensor(0)]
    return torch.where(valid_q_token & valid_kv_token, score + scale_, score)

def attention_steering_scale_down(score, b, h, q_idx, kv_idx):
    # Dampen scores where neither the query token nor the key/value token is highlighted.
    token_mask = highlighted_tokens[b, h]
    valid_q_token = ~token_mask[q_idx]
    valid_kv_token = ~token_mask[kv_idx]
    scale_ = scales[b, h, torch.tensor(0)]
    return torch.where(valid_q_token & valid_kv_token, score - scale_, score)

query = torch.rand(B, H, max_tokens, 8)
key = torch.rand(B, H, max_tokens, 8)
value = torch.rand(B, H, max_tokens, 8)
output = flex_attention(query, key, value, score_mod=attention_steering)
output_scale_down = flex_attention(query, key, value, score_mod=attention_steering_scale_down)
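
One small usage note, in case this becomes an example: the calls above run the eager path, which is fine for checking correctness but slow. FlexAttention is normally wrapped in torch.compile for real workloads:

# Optional: compile once and reuse for performance (standard FlexAttention usage).
flex_attention_compiled = torch.compile(flex_attention)
output_compiled = flex_attention_compiled(query, key, value, score_mod=attention_steering)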

I'm also considering adding a page table (because the mapping between batch id and the 0th-dim index of highlighted_tokens can change) to make this complete.
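
A rough sketch of what I have in mind (purely illustrative; batch_to_slot is just a placeholder name for the lookup table, not part of any API):

# Remap batch ids to rows of highlighted_tokens via a captured lookup tensor,
# so rows can be reassigned without rebuilding the mask.
batch_to_slot = torch.arange(B)  # identity to start; update entries as sequences move between slots

def attention_steering_paged(score, b, h, q_idx, kv_idx):
    slot = batch_to_slot[b]
    token_mask = highlighted_tokens[slot, h]
    scale_ = scales[slot, h, torch.tensor(0)]
    return torch.where(token_mask[q_idx] & token_mask[kv_idx], score + scale_, score)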

If you think this is interesting, I can open a PR against attention-gym to showcase this example.

drisspg commented 2 months ago

Yeah sounds super interesting, feel free to open up a PR!