pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

[Feature request] how to merge Blockmask? #39

Closed: foreverpiano closed this issue 1 month ago

foreverpiano commented 2 months ago

The BlockMask layout is [b, h, S, S]. I want to have some dynamism along the [b, h] dimensions. For example, heads 0-15 should use a causal mask, and heads 16-31 should use sliding_window_causal.

There are two questions:

  1. We actually use two kinds of masks, so is there a way to have a pointer/shared version of the mask and avoid copying it head_num times?
  2. How do I write a single mask function that creates this mask? (When I try to use `if h < 16: ...` inside the function, it raises a torch.compile Dynamic > 1 error.)

FYI


def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW 
    return causal_mask & window_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def half_causal_and_half_sliding_window(b, h, q_idx, kv_idx):
    ...  # TODO: how to combine the two per head?

@drisspg @Chillee

JCBrouwer commented 2 months ago

I'm pretty sure b and h are also just index tensors like q_idx and kv_idx, so you should just integrate them into your mask:

def half_causal_and_half_sliding_window(b, h, q_idx, kv_idx):
    causal_mask = (h < 16) & (q_idx >= kv_idx)
    window_mask = (h >= 16) & (q_idx >= kv_idx) & (q_idx - kv_idx <= SLIDING_WINDOW)
    return causal_mask | window_mask
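
For reference, a minimal end-to-end sketch of how that combined mask_mod could be used with create_block_mask and flex_attention. The shapes, SLIDING_WINDOW value, and device below are hypothetical; the causal term is kept on the sliding-window half so heads 16-31 stay causal as in the original sliding_window_causal:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 2, 32, 1024, 64   # hypothetical batch, heads, sequence length, head dim
SLIDING_WINDOW = 256           # hypothetical window size

def half_causal_and_half_sliding_window(b, h, q_idx, kv_idx):
    causal_mask = (h < 16) & (q_idx >= kv_idx)
    window_mask = (h >= 16) & (q_idx >= kv_idx) & (q_idx - kv_idx <= SLIDING_WINDOW)
    return causal_mask | window_mask

# A single BlockMask covers all heads: the per-head branching lives inside the
# mask_mod, so the mask is not materialized head_num times.
block_mask = create_block_mask(
    half_causal_and_half_sliding_window, B=None, H=H, Q_LEN=S, KV_LEN=S, device="cuda"
)

q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)

Because the head condition is expressed as a tensor predicate (h < 16) rather than a Python `if h < 16` branch, this should avoid the data-dependent control flow that torch.compile rejected.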
foreverpiano commented 2 months ago

@JCBrouwer Makes sense. Thanks for your help.