@CurryTang
Thank you so much for reaching out and for your interest in our project. I'm glad you're finding it valuable!
I'd be happy to explain the implementation detail you've asked about. With Rotary Positional Encoding (RoPE), we indeed encode positional information into both the query and key embeddings. However, the graph-enhanced prefix (what we call the adapter) is used solely as keys and values: the prefix never acts as a query and we never compute attention outputs for the prefix tokens themselves, which is why RoPE is applied only to its keys.
Here's how this looks in the code. For the original sequence, we apply RoPE to both the queries and the keys:
# model.py line 237
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
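For context, `apply_rotary_emb` presumably follows the usual LLaMA-style rotary embedding. Below is a minimal sketch modeled on the public LLaMA reference code rather than copied from model.py; it assumes tensors shaped [batch, seqlen, n_heads, head_dim] and complex frequencies of shape [seqlen, head_dim // 2]:

import torch

def precompute_freqs_cis(head_dim: int, seqlen: int, theta: float = 10000.0) -> torch.Tensor:
    # One rotation frequency per pair of channels, one angle per position.
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seqlen).float(), freqs)   # [seqlen, head_dim // 2]
    return torch.polar(torch.ones_like(angles), angles)         # complex rotations e^{i * m * theta_j}

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # Interpret consecutive channel pairs as complex numbers and rotate them by the
    # position-dependent angles; the same rotation is applied to queries and keys.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs = freqs_cis.view(1, xq_.shape[1], 1, xq_.shape[-1])   # broadcast over batch and heads
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)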
And for the graph-enhanced prefix keys, we apply RoPE:
# model.py line 248
adapter_k = apply_rotary_emb_single(adapter_k, freqs_cis=freqs_cis_prefix)
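`apply_rotary_emb_single` is just the one-tensor counterpart: it rotates only the prefix keys, using frequencies computed for the prefix positions. Here is a hypothetical sketch consistent with the helper above; the real function in model.py may handle shapes differently:

def apply_rotary_emb_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Rotate a single tensor (here the prefix keys); the prefix values are left untouched.
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs = freqs_cis.view(1, x_.shape[1], 1, x_.shape[-1])     # broadcast over batch and heads
    return torch.view_as_real(x_ * freqs).flatten(3).type_as(x)

# freqs_cis_prefix would cover the prefix positions, e.g. (illustrative only):
# freqs_cis_prefix = precompute_freqs_cis(head_dim, prefix_len)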
Then the attention is computed as follows:
# model.py lines 273-277
# Prepend the prefix keys/values along the sequence dimension.
keys = torch.cat([adapter_k, keys], dim=2)      # keys:   [batch_size, prefix_len + seqlen, dim]
values = torch.cat([adapter_v, values], dim=2)  # values: [batch_size, prefix_len + seqlen, dim]
# xq: [batch_size, seqlen, dim] -> output: [batch_size, seqlen, dim]
output = self._forward_scaled_dot_product_attention(xq, keys, values, attention_mask=mask)
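Putting it all together, the adapted attention looks roughly like this. This is an illustrative sketch reusing the helpers sketched above, not a copy of model.py; details such as the projections, the head split/merge, and any gating or scaling of the prefix scores in the real code are simplified or omitted:

import torch
import torch.nn.functional as F

def prefix_attention(xq, xk, xv, adapter_k, adapter_v, freqs_cis, freqs_cis_prefix, mask=None):
    # xq, xk, xv:           [batch, seqlen, n_heads, head_dim]
    # adapter_k, adapter_v: [batch, prefix_len, n_heads, head_dim]
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)                      # RoPE on sequence q and k
    adapter_k = apply_rotary_emb_single(adapter_k, freqs_cis=freqs_cis_prefix)  # RoPE on prefix keys only
    # Move heads in front of the sequence dimension: [batch, n_heads, *, head_dim].
    xq, xk, xv = (t.transpose(1, 2) for t in (xq, xk, xv))
    adapter_k, adapter_v = (t.transpose(1, 2) for t in (adapter_k, adapter_v))
    # Prepend the prefix along the sequence dimension (dim=2 after the transpose).
    keys = torch.cat([adapter_k, xk], dim=2)       # [batch, n_heads, prefix_len + seqlen, head_dim]
    values = torch.cat([adapter_v, xv], dim=2)
    # Queries come only from the original sequence, so no attention output is produced
    # for the prefix tokens; the mask, if given, must also cover the (normally
    # unmasked) prefix columns.
    out = F.scaled_dot_product_attention(xq, keys, values, attn_mask=mask)
    return out.transpose(1, 2)                     # [batch, seqlen, n_heads, head_dim]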
I hope this clears things up! If there's anything more you'd like to discuss or if you need further clarification, please feel free to ask. Your questions are always welcome.
Thanks for your detailed explanation👍
Hi! Thanks for your wonderful work. I have a question about a minor implementation detail: in the adapter of the attention module, the rotary PE is applied only to the key vector. Is there any reference codebase for this design? Many thanks.