Hi, can you explain a bit more about the implementation of axial attention? Is there a reason for performing the calculation step-wise? Is the goal to conserve memory?
```python
for start in range(0, num_rows, max_rows):
    attn_weights = self.compute_attention_weights(
        x[start : start + max_rows],
        scaling,
        self_attn_mask=self_attn_mask,
        self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
        if self_attn_padding_mask is not None
        else None,
    )
    attns += attn_weights
```
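To check my understanding of the memory argument: only one chunk's projected tensors and attention logits would need to be materialized at a time, while the running sum over rows stays small. Here is a minimal, self-contained sketch of how I read that pattern (the function and variable names are my own, not from the repo, and the details are simplified, e.g. no attention heads):

```python
import torch

def row_chunk_attention_weights(x, q_proj, k_proj, scaling, max_rows):
    """Hypothetical illustration: accumulate tied row-attention logits over
    chunks of rows so only one chunk's intermediates are live at a time."""
    num_rows = x.size(0)                               # x: (R, C, B, D)
    attns = 0
    for start in range(0, num_rows, max_rows):
        chunk = x[start : start + max_rows]            # (rows_in_chunk, C, B, D)
        q = q_proj(chunk) * scaling                    # project only the current chunk
        k = k_proj(chunk)
        # sum logits over the row axis -> one (B, C, C) map shared by all rows
        attn_weights = torch.einsum("rcbd,rkbd->bck", q, k)
        attns = attns + attn_weights                   # running sum; chunk tensors can be freed
    return attns

# toy usage
R, C, B, D = 16, 8, 2, 32
x = torch.randn(R, C, B, D)
q_proj = torch.nn.Linear(D, D)
k_proj = torch.nn.Linear(D, D)
weights = row_chunk_attention_weights(x, q_proj, k_proj, scaling=D ** -0.5, max_rows=4)
print(weights.shape)  # torch.Size([2, 8, 8])
```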
Secondly, your Wq, Wk, and Wv matrices have bias terms enabled by default; was there a reason behind this?
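To make the second question concrete, this is the kind of construction I am referring to (a hypothetical snippet, not copied from the repo; PyTorch's `nn.Linear` enables the bias by default):

```python
import torch.nn as nn

embed_dim = 768  # hypothetical size, purely for illustration

# nn.Linear defaults to bias=True, so each projection learns an additive bias term
q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

# some attention implementations disable it explicitly instead
q_proj_no_bias = nn.Linear(embed_dim, embed_dim, bias=False)
```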