alibaba / GraphTranslator

GraphTranslator: Aligning Graph Model to Large Language Model for Open-ended Tasks
BSD 3-Clause "New" or "Revised" License

About the interactions between node embedding and query tokens. #9

Closed: hhy-huang closed this issue 2 months ago

hhy-huang commented 3 months ago

Dear authors and contributors, thank you very much for such remarkable work.

As the code below shows, is the pre-trained GNN output that interacts with the 32 query tokens in each sample just the feature of a single node?

If so, for each sample in a batch, are the 32 query tokens used only to extract useful features from that one node feature via cross-attention?

```python
import torch
import torch.nn.functional as F

# Excerpt from the model's forward pass (inside a method, hence `self`).
behavior_embeds = torch.unsqueeze(samples[1], dim=1)     # samples[1]: [batch_size, 768] node features
text = samples[2]                                        # list of batch_size summaries
behavior_embeds = behavior_embeds.to(self.device)        # [batch_size, 1, 768]
behavior_atts = torch.ones(behavior_embeds.size()[:-1], dtype=torch.long).to(behavior_embeds.device)  # [batch_size, 1]

# The same learned query tokens are shared by every sample in the batch.
query_tokens = self.query_tokens.expand(behavior_embeds.shape[0], -1, -1)  # [batch_size, 32, 768]
query_output = self.Qformer.bert(
    query_embeds=query_tokens,                # [batch_size, 32, 768]
    encoder_hidden_states=behavior_embeds,    # [batch_size, 1, 768], keys/values for cross-attention
    encoder_attention_mask=behavior_atts,     # [batch_size, 1]
    use_cache=True,
    return_dict=True,
)  # last_hidden_state: [batch_size, 32, 768]

behavior_feats = F.normalize(                 # [batch_size, 32, 256]
    self.behavior_proj(query_output.last_hidden_state), dim=-1
)
```
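To make the single-node case concrete, here is a minimal sketch of single-head scaled dot-product cross-attention in NumPy. It assumes random weights, a toy hidden size of 16 in place of 768, and no output projection or residual connection; it is not the Q-Former's BERT implementation, and `cross_attention` is a hypothetical helper. It shows that when `encoder_hidden_states` has length 1, the softmax over keys is exactly 1 for every query, so all 32 queries receive the same attended vector.

```python
import numpy as np

def cross_attention(queries, memory, w_q, w_k, w_v):
    """Single-head scaled dot-product cross-attention (simplified sketch).

    queries: [batch, n_query, d]; memory: [batch, n_mem, d].
    Returns (output [batch, n_query, d], weights [batch, n_query, n_mem]).
    """
    q = queries @ w_q                                          # [B, Nq, d]
    k = memory @ w_k                                           # [B, Nm, d]
    v = memory @ w_v                                           # [B, Nm, d]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])   # [B, Nq, Nm]
    scores -= scores.max(axis=-1, keepdims=True)               # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)             # softmax over memory slots
    return weights @ v, weights

rng = np.random.default_rng(0)
batch_size, n_query, d = 4, 32, 16                  # d=16 stands in for 768
query_tokens = rng.normal(size=(1, n_query, d))
query_tokens = np.broadcast_to(query_tokens, (batch_size, n_query, d))  # like .expand(...)
node_embeds = rng.normal(size=(batch_size, 1, d))  # one node per sample: [B, 1, d]
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))

out, weights = cross_attention(query_tokens, node_embeds, w_q, w_k, w_v)
print(out.shape, weights.shape)                 # (4, 32, 16) (4, 32, 1)
print(np.allclose(weights, 1.0))                # True: softmax over a single key is 1
print(np.allclose(out, node_embeds @ w_v))      # True: every query gets node_embeds @ w_v
```

So within a single cross-attention sublayer, the 32 queries cannot attend differently to a lone node embedding. In a BERT-style Q-Former layer (presumably as in BLIP-2), differences among the 32 output tokens come from the query embeddings themselves, carried through residual connections, self-attention among the queries, and feed-forward sublayers.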
smw1996 commented 3 months ago

First, thank you for your interest in our open-source repository. Yes, the input to the Q-Former's cross-attention is the post-GNN embedding of a single node, while the 32×768 query tokens are trainable parameters. Because the GNN has already aggregated information from the node's neighbors into that embedding, we expect that, once the Graph Translator is trained, the query tokens will have learned how to extract semantic information from both the node itself and its neighbors. The output is then no longer an abstract node embedding but a compressed sequence of 32 token embeddings that carries semantic information a language model can understand. I hope this answers your question.
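The point that a single post-GNN node embedding already carries neighbor information follows from message passing. The minimal sketch below assumes a generic linear mean-aggregation layer (GraphSAGE-style, nonlinearity omitted for clarity); it is not necessarily the GNN GraphTranslator uses, and `mean_aggregate_layer` is a hypothetical helper. Perturbing a neighbor of node 0 changes node 0's output embedding, while perturbing an unconnected node does not.

```python
import numpy as np

def mean_aggregate_layer(x, adj, w_self, w_neigh):
    """One linear mean-aggregation layer (illustrative only).

    x: [num_nodes, d] node features; adj: [num_nodes, num_nodes] 0/1 adjacency.
    h_v = x_v @ w_self + mean_{u in N(v)}(x_u) @ w_neigh
    """
    deg = adj.sum(axis=1, keepdims=True)
    neigh_mean = (adj @ x) / np.maximum(deg, 1)   # avoid division by zero for isolated nodes
    return x @ w_self + neigh_mean @ w_neigh

rng = np.random.default_rng(0)
num_nodes, d = 5, 8
x = rng.normal(size=(num_nodes, d))
adj = np.zeros((num_nodes, num_nodes))
for u, v in [(0, 1), (0, 2), (3, 4)]:             # toy undirected graph
    adj[u, v] = adj[v, u] = 1.0
w_self, w_neigh = rng.normal(size=(d, d)), rng.normal(size=(d, d))

h = mean_aggregate_layer(x, adj, w_self, w_neigh)

x_nb = x.copy()
x_nb[1] += 1.0                                    # perturb node 1, a neighbor of node 0
h_nb = mean_aggregate_layer(x_nb, adj, w_self, w_neigh)

x_far = x.copy()
x_far[3] += 1.0                                   # perturb node 3, not connected to node 0
h_far = mean_aggregate_layer(x_far, adj, w_self, w_neigh)

print(np.allclose(h[0], h_nb[0]))    # False: node 0's embedding depends on node 1
print(np.allclose(h[0], h_far[0]))   # True: node 3 does not reach node 0 in one layer
```

Stacking k such layers extends this dependency to the k-hop neighborhood, which is why cross-attending to one node's embedding can still draw on its surroundings.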