krishnadubba opened 5 years ago

Hi, I am trying to visualize the attention schemes in this code, essentially reproducing Fig. 3 from the paper. I can reproduce the "fixed" attention scheme, but I cannot reproduce the "strided" scheme (Fig. 3b) no matter what parameters I try; only after changing some of the code do I get the strided pattern shown in the paper. Did anyone face the same issue?
After reading attention.py, I find that this codebase only contains separate implementations of the two halves of strided attention, referred to internally as the "first / second step of strided attention". So you likely need to assemble the full strided attention yourself, e.g. as a two-head sparse self-attention in which each head carries out one of those two steps.
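For what it's worth, here is a minimal sketch of that idea in plain TensorFlow. This is not code from attention.py; the function name and the explicit-mask formulation are my own, and it assumes the two steps are the usual local-window and strided-column patterns from the paper:

```python
import tensorflow as tf

def two_step_strided_masks(n_tokens, stride):
    # Hypothetical helper: the two boolean masks that together make up
    # strided attention, intended to be used one per attention head.
    i = tf.range(n_tokens)[:, None]  # query positions
    j = tf.range(n_tokens)[None, :]  # key positions
    causal = i >= j
    # Step 1: attend to the previous `stride` positions (local window).
    step1 = tf.logical_and(causal, (i - j) < stride)
    # Step 2: attend to every position a multiple of `stride` steps back.
    step2 = tf.logical_and(causal, tf.math.floormod(i - j, stride) == 0)
    return step1, step2
```

Each head then gets its own mask, e.g. by adding a large negative value to the disallowed logits before the softmax.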
@krishnadubba Have you successfully implemented the strided version, by the way? Could you share the code change?
I was able to reproduce the patterns using this function.
```python
import tensorflow as tf

def sparse_attention_mask(n_tokens, stride_length=3, c=2):
    # Index grids: Q[i, j] = i (query position), K[i, j] = j (key position).
    x = tf.reshape(tf.range(n_tokens), [n_tokens, 1])
    y = tf.transpose(x)
    z = tf.zeros((n_tokens, n_tokens), dtype=tf.int32)  # int32 so it broadcasts with tf.range
    Q = z + x
    K = z + y
    # Causal mask: each position attends only to itself and earlier positions.
    causal_attention_mask = Q >= K
    # Fixed pattern (Fig. 3c): attend within the current block of `stride_length`
    # tokens, plus the last `c` columns of every block.
    fixed_mask_1 = tf.equal(Q // stride_length, K // stride_length)
    fixed_mask_2 = tf.math.floormod(K, stride_length) >= stride_length - c
    combined_mask_fixed = tf.logical_and(causal_attention_mask,
                                         tf.logical_or(fixed_mask_1, fixed_mask_2))
    # Strided pattern (Fig. 3b): attend to the previous `stride_length` positions,
    # plus every position a multiple of `stride_length` steps back.
    stride_mask_1 = tf.less_equal(Q - K, stride_length)
    stride_mask_2 = tf.equal(tf.math.floormod(Q - K, stride_length), 0)
    combined_mask_stride = tf.logical_and(causal_attention_mask,
                                          tf.logical_or(stride_mask_1, stride_mask_2))
    # Shape [1, 1, n, n] so the masks broadcast over batch and head dimensions.
    return (tf.reshape(combined_mask_fixed, [1, 1, n_tokens, n_tokens]),
            tf.reshape(combined_mask_stride, [1, 1, n_tokens, n_tokens]))
```
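And to actually plot the masks, a quick sketch assuming TF2 eager execution and matplotlib (the parameter values here are arbitrary):

```python
import matplotlib.pyplot as plt

fixed, strided = sparse_attention_mask(n_tokens=24, stride_length=4, c=1)
fig, axes = plt.subplots(1, 2)
axes[0].imshow(fixed[0, 0].numpy())    # fixed pattern, cf. Fig. 3(c)
axes[1].imshow(strided[0, 0].numpy())  # strided pattern, cf. Fig. 3(b)
plt.show()
```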