dinobby / MAGDi

The code implementation of MAGDi: Structured Distillation of Multi-Agent Interaction Graphs Improves Reasoning in Smaller Language Models. Paper: https://arxiv.org/abs/2402.01620
MIT License

It seems that the GCN loss is not updating the model parameters. #2

Open lianshan01 opened 3 months ago

lianshan01 commented 3 months ago
    gcn_output, logits = self.gcn(graph_batch.x, graph_batch.edge_index)
    graph_batch.y = graph_batch.y.to(logits.device)
    node_loss = ce_cri(logits, graph_batch.y)

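The inputs to the GCN loss (`graph_batch.x` and `graph_batch.edge_index`) seem to be independent of the language model, so this loss would not appear to produce gradients for the model's parameters.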
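One way to check a concern like this is to backpropagate the loss on its own and see which parameters receive gradients. The sketch below uses toy stand-ins rather than the MAGDi code: `decoder` and `gcn_head` are hypothetical `nn.Linear` modules, and the `detach()` call models the assumption that node embeddings were computed ahead of time and are no longer attached to the decoder's computation graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a tiny "decoder" that produces node embeddings and a
# linear head playing the role of the GCN classifier. Not the MAGDi modules.
decoder = nn.Linear(8, 4)
gcn_head = nn.Linear(4, 2)
ce_cri = nn.CrossEntropyLoss()

tokens = torch.randn(5, 8)               # fake inputs for 5 nodes
labels = torch.tensor([0, 1, 0, 1, 1])   # fake node labels


def decoder_gets_grad(detach_embeddings: bool) -> bool:
    """Backpropagate the node loss alone and report whether the decoder got a gradient."""
    decoder.zero_grad(set_to_none=True)
    gcn_head.zero_grad(set_to_none=True)
    x = decoder(tokens)
    if detach_embeddings:
        # Assumption: models embeddings that were precomputed before training.
        x = x.detach()
    node_loss = ce_cri(gcn_head(x), labels)
    node_loss.backward()
    return decoder.weight.grad is not None


detached_result = decoder_gets_grad(detach_embeddings=True)
attached_result = decoder_gets_grad(detach_embeddings=False)
print("decoder receives gradient (detached embeddings):", detached_result)
print("decoder receives gradient (attached embeddings):", attached_result)
```

In the real training loop, the equivalent check would be to inspect `.grad` on a decoder parameter after calling `backward()` on `node_loss` alone.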

dinobby commented 3 months ago

Hi @lianshan01, thanks for your comment! The input to the GCN is the set of node embeddings obtained from the Multi-Agent Interaction Graphs (MAGs) using the same underlying decoder (Mistral). We update the model with the sum of three objectives, i.e., the model is trained with multi-task learning. Hope this answers your question!
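As a rough illustration of the multi-task setup described in this reply, the toy sketch below sums three losses computed from one shared module and confirms that the gradient of the sum equals the sum of the per-loss gradients. Everything here is hypothetical: `shared` and `head` are small `nn.Linear` stand-ins, and `loss_a` and `loss_b` are arbitrary placeholders, since the thread only says that three objectives are summed. The third term reaches `shared` only because `h` stays attached to the computation graph in this sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

shared = nn.Linear(6, 3)   # hypothetical shared model (stand-in for the decoder)
head = nn.Linear(3, 2)     # hypothetical task head (stand-in for the GCN)
x = torch.randn(4, 6)
y = torch.tensor([0, 1, 1, 0])


def losses():
    """Run a fresh forward pass and return three placeholder objectives."""
    h = shared(x)                            # h stays attached to `shared`
    loss_a = h.pow(2).mean()                 # placeholder objective 1
    loss_b = (h - 1.0).abs().mean()          # placeholder objective 2
    node_loss = F.cross_entropy(head(h), y)  # placeholder node-classification loss
    return loss_a, loss_b, node_loss


def grad_of(loss):
    """Return the gradient of `loss` with respect to the shared weights."""
    shared.zero_grad(set_to_none=True)
    head.zero_grad(set_to_none=True)
    loss.backward()
    return shared.weight.grad.clone()


loss_a, loss_b, node_loss = losses()
total_grad = grad_of(loss_a + loss_b + node_loss)

# Recompute the forward pass for each term so every backward has its own graph.
separate_grads = [grad_of(losses()[i]) for i in range(3)]
sum_of_grads = torch.stack(separate_grads).sum(dim=0)

matches = torch.allclose(total_grad, sum_of_grads, atol=1e-6)
node_term_reaches_shared = separate_grads[2].abs().sum().item() > 0
print("gradient of sum equals sum of gradients:", matches)
print("node loss contributes to shared weights:", node_term_reaches_shared)
```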