foreverpiano closed this 1 month ago
I'm pretty sure b and h are also just index tensors like q_idx and kv_idx, so you should just integrate them into your mask:
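def half_causal_and_half_sliding_window(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    # Heads 0-15: plain causal mask.
    causal_mask = (h < 16) & causal
    # Heads 16-31: causal sliding window (the causal term keeps future tokens masked).
    window_mask = (h >= 16) & causal & (q_idx - kv_idx <= SLIDING_WINDOW)
    return causal_mask | window_mask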
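To sanity-check the per-head logic without a GPU or torch.compile, here is a minimal sketch that evaluates the same kind of mask function on plain Python ints. It works because & and | behave the same way on Python bools as on boolean tensors. The window size, the NUM_CAUSAL_HEADS constant and the dense_mask helper are assumptions made for illustration, and the window branch is assumed to be causal as well, as the question asks for.

```python
SLIDING_WINDOW = 2      # assumed small window, for illustration only
NUM_CAUSAL_HEADS = 16   # heads 0-15 causal, heads 16+ sliding window (from the thread)


def half_causal_and_half_sliding_window(b, h, q_idx, kv_idx):
    # No Python `if` on h: the head choice is folded into the boolean expression,
    # so the same body also works when the arguments are index tensors.
    causal = q_idx >= kv_idx
    causal_mask = (h < NUM_CAUSAL_HEADS) & causal
    window_mask = (h >= NUM_CAUSAL_HEADS) & causal & (q_idx - kv_idx <= SLIDING_WINDOW)
    return causal_mask | window_mask


def dense_mask(h, seq_len):
    # Hypothetical helper: materialize the [S, S] mask for one head as 0/1 rows.
    return [
        [int(half_causal_and_half_sliding_window(0, h, q, kv)) for kv in range(seq_len)]
        for q in range(seq_len)
    ]


if __name__ == "__main__":
    for h in (0, 16):
        print(f"head {h}:")
        for row in dense_mask(h, 5):
            print(" ", row)
```

Head 0 should print a full lower triangle, while head 16 should print a band that keeps only the last SLIDING_WINDOW + 1 positions up to the diagonal. Because the function contains no data-dependent branching, the same body is the shape that would be passed as a mask_mod.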
@JCBrouwer That makes sense. Thanks for your help.
The block mask layout is [b, h, S, S]. I want the mask to vary across [b, h]. For example, heads 0-15 use a causal mask, and heads 16-31 use sliding_window_causal.
There are two things:
Using if h < 16.... in the function gives a torch.compile error (Dynamic>1).
FYI @drisspg @Chillee