openai / sparse_attention

Examples of using sparse attention, as in "Generating Long Sequences with Sparse Transformers"

Problem with reproducing "strided" attention scheme from the paper #7

Open krishnadubba opened 5 years ago

krishnadubba commented 5 years ago

Hi, I am trying to visualize the attention schemes using this code, basically trying to reproduce Fig. 3 from the paper. I could reproduce the "fixed" attention scheme, as shown below:

[image: fixed_sparse_attn]

The problem is that I could not reproduce the "strided" scheme (Fig. 3b in the paper). No matter what parameters I try, all I get is the following:

[image: strided_wrong]

If I change some of the code, I get the correct "strided" version shown in the paper. This is the result after those changes:

[image: strided_correct]

Did anyone face the same issue?
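For reference, the paper describes each pattern as two index sets $A_i^{(1)}$ and $A_i^{(2)}$ that query position $i$ may attend to, with stride $l$; the fixed pattern also has a width hyperparameter $c$. Paraphrased from the paper (in the causal setting every set is further restricted to $j \le i$):

$$
\begin{aligned}
\text{Strided:}\quad & A_i^{(1)} = \{t, t+1, \ldots, i\}, \quad t = \max(0,\, i - l), \\
& A_i^{(2)} = \{\, j : (i - j) \bmod l = 0 \,\}, \\
\text{Fixed:}\quad & A_i^{(1)} = \{\, j : \lfloor j/l \rfloor = \lfloor i/l \rfloor \,\}, \\
& A_i^{(2)} = \{\, j : j \bmod l \in \{t, t+1, \ldots, l\} \,\}, \quad t = l - c.
\end{aligned}
$$

Fig. 3b shows the strided pattern as the union $A_i^{(1)} \cup A_i^{(2)}$, so a plot that draws only one of the two sets will not match it.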

pengfeiZhao1993 commented 4 years ago

After reading the code in "attention.py", I find that it only contains separate implementations of the two parts of strided attention, labeled "first / second step of strided attention". You will therefore probably need to implement the complete strided attention yourself, as a two-head sparse self-attention in which each head corresponds to one of those two steps.
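A minimal plain-Python sketch of that suggestion. It assumes the two steps are (1) attend to the previous `l` positions and (2) attend to every earlier position a multiple of `l` back, as in the paper's strided pattern. The helper names `local_step`, `stride_step`, `union`, and `render` are hypothetical and are not taken from attention.py. Each step is built as one head's causal mask, and the union draws the full Fig. 3b pattern in a single picture.

```python
# Hypothetical helpers (not from attention.py): each "step" of strided
# attention is built as its own causal boolean mask, i.e. one per head.


def local_step(n, l):
    """Head 1: query i attends to itself and the previous l positions."""
    return [[0 <= i - j <= l for j in range(n)] for i in range(n)]


def stride_step(n, l):
    """Head 2: query i attends to every earlier j with (i - j) % l == 0."""
    return [[j <= i and (i - j) % l == 0 for j in range(n)] for i in range(n)]


def union(a, b):
    """Elementwise OR of two masks, giving the combined strided pattern."""
    return [[x or y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def render(mask):
    """ASCII picture of a mask: '#' = may attend, '.' = masked out."""
    return "\n".join("".join("#" if v else "." for v in row) for row in mask)


n, l = 8, 3
heads = [local_step(n, l), stride_step(n, l)]
combined = union(*heads)
print(render(combined))
```

The union is only a convenient way to visualize the pattern. In an actual two-head layer each head would keep its own mask, which matches the "one head per step" layout described above.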

benathi commented 1 year ago

@krishnadubba Have you successfully implemented the strided version, btw? Could you share the code change?

jaindhairyahere commented 5 months ago

> @krishnadubba Have you successfully implemented the strided version, btw? Could you share the code change?

I was able to reproduce both patterns using this function:

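```python
import tensorflow as tf


def sparse_attention_mask(n_tokens, stride_length=3, c=2):
    """Return (fixed, strided) causal boolean masks of shape [1, 1, n_tokens, n_tokens]."""
    # Query (row) and key (column) indices. The zeros are int32 so the
    # additions below don't mix float32 and int32 tensors.
    x = tf.reshape(tf.range(n_tokens), [n_tokens, 1])
    y = tf.transpose(x)
    z = tf.zeros((n_tokens, n_tokens), dtype=tf.int32)
    Q = z + x  # Q[i, j] = i
    K = z + y  # K[i, j] = j

    causal_attention_mask = Q >= K

    # "Fixed": same block of length stride_length, or one of the last c
    # "summary" positions of any block.
    fixed_mask_1 = tf.equal(Q // stride_length, K // stride_length)
    fixed_mask_2 = tf.logical_and(
        tf.math.floormod(K, stride_length) >= stride_length - c,
        tf.math.floormod(K, stride_length) <= stride_length,  # always true; mirrors {t, ..., l}
    )
    combined_mask_fixed = tf.logical_and(
        causal_attention_mask, tf.logical_or(fixed_mask_1, fixed_mask_2))

    # "Strided": the previous stride_length positions, or any position
    # a multiple of stride_length back.
    stride_mask_1 = tf.less_equal(Q - K, stride_length)
    stride_mask_2 = tf.equal(tf.math.floormod(Q - K, stride_length), 0)
    combined_mask_stride = tf.logical_and(
        causal_attention_mask, tf.logical_or(stride_mask_1, stride_mask_2))

    shape = [1, 1, n_tokens, n_tokens]
    return tf.reshape(combined_mask_fixed, shape), tf.reshape(combined_mask_stride, shape)
```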
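For a quick dependency-free check of the fixed mask without TensorFlow, here is a plain-Python sketch that evaluates the same two conditions elementwise and prints an ASCII grid ('#' = may attend). `fixed_mask` is a hypothetical helper that mirrors `fixed_mask_1` / `fixed_mask_2` from the function above. It is not part of the repository.

```python
def fixed_mask(n_tokens, stride_length=3, c=2):
    """Causal "fixed" pattern as a list of boolean rows (hypothetical helper)."""
    mask = []
    for q in range(n_tokens):  # query position i
        row = []
        for k in range(n_tokens):  # key position j
            same_block = q // stride_length == k // stride_length  # fixed_mask_1
            summary = k % stride_length >= stride_length - c  # fixed_mask_2
            row.append(k <= q and (same_block or summary))  # causal restriction
        mask.append(row)
    return mask


for row in fixed_mask(9, stride_length=3, c=1):
    print("".join("#" if allowed else "." for allowed in row))
```

With `c=1`, the last position of every block of length 3 (positions 2, 5, 8) acts as a summary column that all later queries can see. The `<= stride_length` clause in `fixed_mask_2` is always true, because `floormod(K, stride_length)` never exceeds `stride_length - 1`, so the sketch omits it.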