mistyreed63849 / Graph-LLM

GraphLLM: Boosting Graph Reasoning Ability of Large Language Model
https://arxiv.org/abs/2310.05845

Questions about some implementation details #1

Closed. CurryTang closed this issue 1 year ago.

CurryTang commented 1 year ago

Hi! Thanks for your wonderful work. I have a few questions about some minor implementation details. For example,

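            # adapter holds the graph-enhanced prefix: keys and values, each [bsz, adapter_len, dim]
            adapter_key, adapter_value = adapter
            adapter_len = adapter_key.shape[1]

            # project with the layer's key/value weights and split into heads
            adapter_k = self.wk(adapter_key)
            adapter_k = adapter_k.view(bsz, adapter_len, self.n_heads, self.head_dim)
            adapter_v = self.wv(adapter_value)
            adapter_v = adapter_v.view(bsz, adapter_len, self.n_heads, self.head_dim)

            # rotary embedding applied to the prefix keys only
            adapter_k = apply_rotary_emb_single(adapter_k, freqs_cis=freqs_cis_prefix)

            # [bsz, adapter_len, n_heads, head_dim] -> [bsz, n_heads, adapter_len, head_dim]
            adapter_k = adapter_k.transpose(1, 2)
            adapter_v = adapter_v.transpose(1, 2)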

In the adapter of the attention module, the rotary PE is only applied to the key vector. Is there any reference codebase for this design? Many thanks.

godcherry commented 1 year ago

@CurryTang

Thank you so much for reaching out and for your interest in our project. I'm glad you're finding it valuable!

I'd be happy to explain the implementation detail you've asked about. With Rotary Position Embedding (RoPE), positional information is normally encoded into both the query and the key embeddings. Our graph-enhanced prefix (what we call the adapter), however, is used solely as keys and values: we never compute attention outputs for the prefix tokens themselves. There are therefore no prefix queries to rotate, and applying RoPE to the prefix keys is sufficient.
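To see why rotating only the prefix keys is enough, recall that RoPE rotates each pair of dimensions by an angle proportional to the token's position, so a query at position m and a key at position n interact only through the offset m - n. The following is a minimal pure-Python sketch, not the repository's code. It assumes the consecutive-pair convention of the original LLaMA reference implementation, and the helpers `rope_freqs`, `rotate`, and `dot` are hypothetical names used only for illustration.

```python
import cmath
import math

# Hypothetical helpers; assumes LLaMA-style consecutive dimension pairs.
def rope_freqs(dim, base=10000.0):
    # One frequency per pair of dimensions: theta_j = base^(-2j/dim).
    return [base ** (-2 * j / dim) for j in range(dim // 2)]

def rotate(vec, pos, freqs):
    # Treat (vec[2j], vec[2j+1]) as the complex number vec[2j] + i*vec[2j+1]
    # and rotate it by the angle pos * theta_j.
    out = []
    for j, theta in enumerate(freqs):
        z = complex(vec[2 * j], vec[2 * j + 1]) * cmath.exp(1j * pos * theta)
        out.extend([z.real, z.imag])
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

dim = 8
freqs = rope_freqs(dim)
q = [0.3, -1.2, 0.5, 0.9, -0.4, 0.1, 1.1, -0.7]
k = [1.0, 0.2, -0.6, 0.4, 0.8, -0.3, 0.05, 0.6]

# Query at position m, key at position n: the score depends only on m - n.
s1 = dot(rotate(q, 7, freqs), rotate(k, 4, freqs))   # offset 3
s2 = dot(rotate(q, 10, freqs), rotate(k, 7, freqs))  # offset 3
print(math.isclose(s1, s2, rel_tol=1e-9, abs_tol=1e-12))  # True
```

A prefix token never issues a query, so its key only ever sits on the key side of such a dot product; rotating that key once is all the relative-position mechanism needs.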

Here's how we apply this in the code: For the original sequence queries and keys, we apply RoPE as shown in:

# model.py line 237
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

And for the graph-enhanced prefix, we apply RoPE to the keys only:

# model.py line 248
adapter_k = apply_rotary_emb_single(adapter_k, freqs_cis=freqs_cis_prefix)
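The difference between the two calls can be sketched as follows: the paired function rotates queries and keys with one frequency table, while the single-tensor variant rotates just one tensor, here the prefix keys. This is a hypothetical pure-Python stand-in: `precompute_freqs_cis`, `apply_rotary_single`, and `apply_rotary_pair` are illustrative names, not the repository's API. The thread also does not say which positions `freqs_cis_prefix` encodes, so the sketch simply assumes positions 0 to prefix_len - 1.

```python
import cmath

# Hypothetical pure-Python analogues; not the repository's functions.
def precompute_freqs_cis(dim, positions, base=10000.0):
    # freqs_cis[t][j] = exp(i * pos_t * theta_j) with theta_j = base^(-2j/dim):
    # one unit complex number per (position, dimension pair).
    thetas = [base ** (-2 * j / dim) for j in range(dim // 2)]
    return [[cmath.exp(1j * p * th) for th in thetas] for p in positions]

def apply_rotary_single(x, freqs_cis):
    # x: one head, as a list over positions of real vectors of even length.
    # Each consecutive pair (x[2j], x[2j+1]) is rotated by the matching table entry.
    out = []
    for vec, row in zip(x, freqs_cis):
        rotated = []
        for j, f in enumerate(row):
            z = complex(vec[2 * j], vec[2 * j + 1]) * f
            rotated.extend([z.real, z.imag])
        out.append(rotated)
    return out

def apply_rotary_pair(xq, xk, freqs_cis):
    # The paired variant applies the same rotation to queries and keys.
    return apply_rotary_single(xq, freqs_cis), apply_rotary_single(xk, freqs_cis)

# Hypothetical prefix of 3 key vectors (head_dim = 4).
# Assumption: the prefix is given positions 0..2.
dim = 4
prefix_keys = [[1.0, 0.0, 0.0, 1.0], [0.5, 0.5, -1.0, 2.0], [3.0, -1.0, 0.2, 0.0]]
freqs_cis_prefix = precompute_freqs_cis(dim, positions=range(len(prefix_keys)))
rotated_prefix_keys = apply_rotary_single(prefix_keys, freqs_cis_prefix)
```

Because each rotation has unit magnitude, the prefix keys keep their norms; only their phases change with position.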

The attention is then computed as follows:

# model.py lines 273-277

keys = torch.cat([adapter_k, keys], dim=2)    # keys: [batch_size, n_heads, prefix_len+seqlen, head_dim]
values = torch.cat([adapter_v, values], dim=2)   # values: [batch_size, n_heads, prefix_len+seqlen, head_dim]
output = self._forward_scaled_dot_product_attention(xq, keys, values, attention_mask=mask)  # xq: [batch_size, n_heads, seqlen, head_dim]; output: one position per query (seqlen), none for the prefix
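The effect of the concatenation can be illustrated with a single head in pure Python: keys and values are laid out as [prefix; sequence], but queries come only from the sequence, so the output has one row per sequence position and none for the prefix. This is a sketch, not the repository's `_forward_scaled_dot_product_attention`; `prefix_attention` is a hypothetical name, and the mask (every query sees the whole prefix plus a causal window over the sequence) is an assumption, since the thread does not show how `mask` is built.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def prefix_attention(queries, prefix_keys, prefix_values, keys, values):
    # Hypothetical single-head attention over [prefix; sequence] keys/values,
    # mirroring torch.cat([adapter_k, keys], dim=2) along the sequence axis.
    all_keys = prefix_keys + keys
    all_values = prefix_values + values
    prefix_len, d = len(prefix_keys), len(queries[0])
    outputs = []
    for t, q in enumerate(queries):  # queries come from the sequence only
        # Assumed mask: full prefix plus sequence positions <= t (causal).
        visible = prefix_len + t + 1
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d)
                  for k in all_keys[:visible]]
        weights = softmax(scores)
        outputs.append([sum(w * v[i] for w, v in zip(weights, all_values[:visible]))
                        for i in range(len(all_values[0]))])
    return outputs

queries = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]  # seqlen = 3
prefix_keys = [[1.0, 0.0], [0.0, 1.0]]           # prefix_len = 2
prefix_values = [[1.0, 1.0], [2.0, 2.0]]
keys = [[0.5, 0.5], [-0.5, 0.5], [0.2, -0.3]]
values = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]
out = prefix_attention(queries, prefix_keys, prefix_values, keys, values)
print(len(out), len(out[0]))  # 3 2
```

The prefix influences every output row through its keys and values, yet it never produces an output row of its own, which is why the prefix needs no query projection and no query-side RoPE.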

I hope this clears things up! If there's anything more you'd like to discuss or if you need further clarification, please feel free to ask. Your questions are always welcome.

CurryTang commented 1 year ago

Thanks for your detailed explanation👍