DeepGraphLearning / GearNet

GearNet and Geometric Pretraining Methods for Protein Structure Representation Learning, ICLR'2023 (https://arxiv.org/abs/2203.06125)
MIT License

Asking about implementation of series connection of PLM & GNN in the FusionNetwork. #53

Closed yunxiaoliCB closed 11 months ago

yunxiaoliCB commented 11 months ago

Hi, I've learned a lot from this great work. Thank you for presenting it in the paper and here!

I wanted to ask about the implementation of the series connection of PLM & GNN in the FusionNetwork. In the PLM+GNN paper (Zhang, Z. et al., "Enhancing Protein Language Models with Structure-based Encoder and Pre-training", arXiv (2023), doi:10.48550/arXiv.2303.06275), the authors tested three ways of fusing the PLM & GNN and decided to use the series connection, which is described as

Series: we replace the node features of GearNet with the output of ESM-1b and use the output of GearNet as final representations.

In the implementation of FusionNetwork, I saw it indeed uses the output of ESM-1b as the node features of GearNet, but then it seems to use the output of GearNet concatenated with the output of ESM-1b as the final representations (pasted below). So which way did the authors find most effective? Should one use the sole output from GearNet or the concatenated output?

    def forward(self, graph, input, all_loss=None, metric=None):
        # Run the sequence model (ESM-1b) first.
        output1 = self.sequence_model(graph, input, all_loss, metric)
        node_output1 = output1.get("node_feature", output1.get("residue_feature"))
        # Feed its residue features into the structure model (GearNet) as node features.
        output2 = self.structure_model(graph, node_output1, all_loss, metric)
        node_output2 = output2.get("node_feature", output2.get("residue_feature"))
        # Concatenate the outputs of both models as the final representations.
        node_feature = torch.cat([node_output1, node_output2], dim=-1)
        graph_feature = torch.cat([
            output1['graph_feature'],
            output2['graph_feature']
        ], dim=-1)
        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }
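For comparison, here is a minimal sketch of a hypothetical variant that follows the paper's literal "Series" description and returns only the GearNet outputs. The class name `SeriesOnlyFusion` and the `sequence_model` / `structure_model` interface are my assumptions based on the snippet above, not part of the released code:

```python
import torch
import torch.nn as nn


class SeriesOnlyFusion(nn.Module):
    """Hypothetical variant of FusionNetwork that returns only the
    structure-model (GearNet) outputs, matching the paper's literal
    "Series" description. Assumes the same sequence_model /
    structure_model interface as the snippet above."""

    def __init__(self, sequence_model, structure_model):
        super().__init__()
        self.sequence_model = sequence_model
        self.structure_model = structure_model

    def forward(self, graph, input, all_loss=None, metric=None):
        # Run the sequence model (ESM-1b) and take its residue features.
        output1 = self.sequence_model(graph, input, all_loss, metric)
        node_output1 = output1.get("node_feature", output1.get("residue_feature"))
        # Use them as node features for the structure model (GearNet).
        output2 = self.structure_model(graph, node_output1, all_loss, metric)
        node_output2 = output2.get("node_feature", output2.get("residue_feature"))
        # Return only the GearNet outputs, without concatenation.
        return {
            "graph_feature": output2["graph_feature"],
            "node_feature": node_output2,
        }
```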

If possible, could you please share some configurations for trying out the "cross" style (quoted below) of fusing the PLM & GNN? I am interested in testing this option and would like to know the transformer configurations (number of layers, hidden dimensions, number of heads) that you have tried.

Cross: we concatenate the output of ESM-1b and GearNet and then feed them into a transformer to perform cross-attention between modalities. The output of the transformer will be used as final representations.
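For reference, a minimal sketch of what such a cross-attention fusion could look like, where the two residue-level streams are concatenated along the sequence axis and self-attention over the joint sequence acts as cross-attention between modalities. The layer sizes (`num_layers`, `d_model`, `nhead`) below are illustrative guesses, not the authors' configuration:

```python
import torch
import torch.nn as nn


class CrossFusion(nn.Module):
    """Hypothetical sketch of the "Cross" fusion style: residue-level
    PLM and GNN outputs are concatenated along the residue axis and
    passed through a small TransformerEncoder, so attention spans both
    modalities. Layer sizes are placeholders."""

    def __init__(self, d_model=512, nhead=8, num_layers=2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, plm_feat, gnn_feat):
        # plm_feat, gnn_feat: [batch, num_residues, d_model]
        # Concatenate the two modality streams along the residue axis.
        joint = torch.cat([plm_feat, gnn_feat], dim=1)
        fused = self.encoder(joint)  # [batch, 2 * num_residues, d_model]
        # Average the two modality halves back to one feature per residue.
        n = plm_feat.shape[1]
        return (fused[:, :n] + fused[:, n:]) / 2
```

Both inputs are assumed to have been projected to the same hidden size `d_model` beforehand.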

Oxer11 commented 11 months ago

Hi, sorry for the late update.

You can find the latest update of our paper and code here:

- Announcement: https://x.com/Oxer22/status/1717167378067316854?s=20
- Paper: https://arxiv.org/abs/2303.06275
- Code: https://github.com/DeepGraphLearning/ESM-GearNet

To answer your question: yes, using the output of GearNet concatenated with the output of ESM-1b as the final representations works best.

yunxiaoliCB commented 11 months ago

Thank you for pointing me there. Kudos on the new release!