DeepGraphLearning / GearNet

GearNet and Geometric Pretraining Methods for Protein Structure Representation Learning, ICLR'2023 (https://arxiv.org/abs/2203.06125)
MIT License

The pre-trained GearNet-Edge model for Fold Classification #41

Open arantir123 opened 11 months ago

arantir123 commented 11 months ago

Thank you for your amazing work! I found that for the Fold Classification task, the GearNet-Edge model is implemented with the GearNetIEConv class rather than the GearNet class, and the two differ in some details (e.g., the extra input embedding and the IEConv layers). Given this, could you provide the GearNet-Edge model pretrained with Multiview Contrast learning on top of GearNetIEConv for Fold Classification (rather than the GearNet-based one used for the EC task)? Thank you.

Oxer11 commented 11 months ago

Hi, the config file for GearNet-Edge-IEConv on Fold is config/Fold3D/gearnet_edge_ieconv.yaml. The pre-trained checkpoints of GearNet-Edge can be found at https://zenodo.org/record/7723075.

arantir123 commented 11 months ago

Thank you. It seems that fold_mc_gearnet_edge_ieconv.pth contains both the encoder and decoder parameters after finetuning. I would like to run some experiments myself: specifically, I need the pretrained GearNet-Edge-IEConv encoder before finetuning, the finetuning configuration file, and the corresponding run command (e.g., how many GPUs and what batch size were actually used), so that I can redo the finetuning on my own. Would it be convenient for you to provide these? Thank you very much.
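(I checked this by listing the parameter names in the file. A minimal sketch; the "model" key and the prefix grouping are assumptions about how torchdrug tasks nest and name their submodules, so inspect the actual keys first:)

import torch

ckpt = torch.load("fold_mc_gearnet_edge_ieconv.pth", map_location="cpu")
state_dict = ckpt.get("model", ckpt)   # unwrap if the weights are nested
# Group parameter names by their top-level prefix; separate prefixes for the
# encoder and the prediction head would confirm both are in the checkpoint.
prefixes = sorted({k.split(".")[0] for k in state_dict})
print(prefixes)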

Oxer11 commented 11 months ago

I see. The original pre-trained checkpoints were deleted by my cluster. I've pre-trained a new GearNet-Edge-IEConv recently. You can download the checkpoint from this link and have a try. Please ping me if there is any problem with the checkpoint.

For finetuning, just use the following command

python script/downstream.py -c config/downstream/Fold3D/gearnet_edge_ieconv.yaml --gpus [0] --ckpt <path_to_your_model>
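If you want multiple GPUs, the same script can be launched in distributed mode, following the pattern shown in the repository README; the 4-GPU count below is only an example, not a confirmed record of what was used for the released checkpoint:

python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/downstream/Fold3D/gearnet_edge_ieconv.yaml --gpus [0,1,2,3] --ckpt <path_to_your_model>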
arantir123 commented 11 months ago

> I see. The original pre-trained checkpoints were deleted by my cluster. I've pre-trained a new GearNet-Edge-IEConv recently. You can download the checkpoint from this link and have a try. Please ping me if there is any problem with the checkpoint.
>
> For finetuning, just use the following command
>
> python script/downstream.py -c config/downstream/Fold3D/gearnet_edge_ieconv.yaml --gpus [0] --ckpt <path_to_your_model>

Thank you very much. I will give it a try.

arantir123 commented 10 months ago

> I see. The original pre-trained checkpoints were deleted by my cluster. I've pre-trained a new GearNet-Edge-IEConv recently. You can download the checkpoint from this link and have a try. Please ping me if there is any problem with the checkpoint.
>
> For finetuning, just use the following command
>
> python script/downstream.py -c config/downstream/Fold3D/gearnet_edge_ieconv.yaml --gpus [0] --ckpt <path_to_your_model>

Hi, it seems that the checkpoint behind the above link does not match the model sizes in the official https://zenodo.org/record/7723075 release: the hidden dimensions of the layers differ. I guess the model in https://zenodo.org/record/7723075 was trained with the newer implementation of GearNet-Edge-IEConv (with the extra input embedding, etc.) quoted below, after the diagnostic sketch.
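(A quick way to pin down such a mismatch is to diff the parameter shapes in the two state dicts. A minimal sketch; the "model" key and the file names are assumptions, so adjust them to the actual checkpoints:)

import torch

# Hypothetical file names; point these at the two checkpoints being compared.
a = torch.load("new_pretrained.pth", map_location="cpu")
b = torch.load("fold_mc_gearnet_edge_ieconv.pth", map_location="cpu")
# Some checkpoints nest the weights under a "model" key; unwrap if present.
a, b = a.get("model", a), b.get("model", b)

for k in sorted(set(a) & set(b)):
    if a[k].shape != b[k].shape:
        print(f"{k}: {tuple(a[k].shape)} vs {tuple(b[k].shape)}")
for k in sorted(set(a) ^ set(b)):
    print(f"only in one checkpoint: {k}")

The newer implementation, for reference: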

# Imports needed to make this excerpt self-contained (as in the repo's gearnet/model.py):
from collections.abc import Sequence

import torch
from torch import nn
from torch.nn import functional as F

from torchdrug import core, layers
from torchdrug.core import Registry as R

from gearnet import layer


@R.register("models.GearNetIEConv")
class GearNetIEConv(nn.Module, core.Configurable):

    def __init__(self, input_dim, embedding_dim, hidden_dims, num_relation, edge_input_dim=None,
                 batch_norm=False, activation="relu", concat_hidden=False, short_cut=True,
                 readout="sum", dropout=0, num_angle_bin=None, layer_norm=False, use_ieconv=False):
        super(GearNetIEConv, self).__init__()
        print('using GearNetIEConv.')

        if not isinstance(hidden_dims, Sequence):
            hidden_dims = [hidden_dims]
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
        # optional input embedding: project input_dim -> embedding_dim before message passing
        self.dims = [embedding_dim if embedding_dim > 0 else input_dim] + list(hidden_dims)
        self.edge_dims = [edge_input_dim] + self.dims[:-1]
        self.num_relation = num_relation
        self.concat_hidden = concat_hidden
        self.short_cut = short_cut
        self.num_angle_bin = num_angle_bin
        self.layer_norm = layer_norm
        self.use_ieconv = use_ieconv

        if embedding_dim > 0:
            self.linear = nn.Linear(input_dim, embedding_dim)
            self.embedding_batch_norm = nn.BatchNorm1d(embedding_dim)

        self.layers = nn.ModuleList()
        self.ieconvs = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            # note that these layers are from gearnet.layer instead of torchdrug.layers
            self.layers.append(layer.GeometricRelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation,
                                                                  None, batch_norm, activation))
            if use_ieconv:
                self.ieconvs.append(layer.IEConvLayer(self.dims[i], self.dims[i] // 4,
                                    self.dims[i + 1], edge_input_dim=14, kernel_hidden_dim=32))
        if num_angle_bin:
            # edge message passing over the spatial line graph (GearNet-Edge)
            self.spatial_line_graph = layers.SpatialLineGraph(num_angle_bin)
            self.edge_layers = nn.ModuleList()
            for i in range(len(self.edge_dims) - 1):
                self.edge_layers.append(layer.GeometricRelationalGraphConv(
                    self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation))

        if layer_norm:
            self.layer_norms = nn.ModuleList()
            for i in range(len(self.dims) - 1):
                self.layer_norms.append(nn.LayerNorm(self.dims[i + 1]))

        self.dropout = nn.Dropout(dropout)

        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def get_ieconv_edge_feature(self, graph):
        # build a local orthonormal frame per residue from consecutive node positions
        u = torch.ones_like(graph.node_position)
        u[1:] = graph.node_position[1:] - graph.node_position[:-1]
        u = F.normalize(u, dim=-1)
        b = torch.ones_like(graph.node_position)
        b[:-1] = u[:-1] - u[1:]
        b = F.normalize(b, dim=-1)
        n = torch.ones_like(graph.node_position)
        n[:-1] = torch.cross(u[:-1], u[1:])
        n = F.normalize(n, dim=-1)

        local_frame = torch.stack([b, n, torch.cross(b, n)], dim=-1)

        node_in, node_out = graph.edge_list.t()[:2]
        t = graph.node_position[node_out] - graph.node_position[node_in]
        t = torch.einsum('ijk, ij->ik', local_frame[node_in], t)
        r = torch.sum(local_frame[node_in] * local_frame[node_out], dim=1)
        delta = torch.abs(graph.atom2residue[node_in] - graph.atom2residue[node_out]).float() / 6
        delta = delta.unsqueeze(-1)

        # 14-dim edge feature: relative translation, frame alignment, sequence distance
        return torch.cat([
            t, r, delta,
            1 - 2 * t.abs(), 1 - 2 * r.abs(), 1 - 2 * delta.abs()
        ], dim=-1)

    def forward(self, graph, input, all_loss=None, metric=None):
        hiddens = []
        layer_input = input
        if self.embedding_dim > 0:
            layer_input = self.linear(layer_input)
            layer_input = self.embedding_batch_norm(layer_input)
        if self.num_angle_bin:
            line_graph = self.spatial_line_graph(graph)
            edge_hidden = line_graph.node_feature.float()
        else:
            edge_hidden = None
        ieconv_edge_feature = self.get_ieconv_edge_feature(graph)

        for i in range(len(self.layers)):
            # edge message passing
            if self.num_angle_bin:
                edge_hidden = self.edge_layers[i](line_graph, edge_hidden)
            hidden = self.layers[i](graph, layer_input, edge_hidden)
            # ieconv layer
            if self.use_ieconv:
                hidden = hidden + self.ieconvs[i](graph, layer_input, ieconv_edge_feature)
            hidden = self.dropout(hidden)

            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input

            if self.layer_norm:
                hidden = self.layer_norms[i](hidden)
            hiddens.append(hidden)
            layer_input = hidden

        if self.concat_hidden:
            node_feature = torch.cat(hiddens, dim=-1)
        else:
            node_feature = hiddens[-1]
        graph_feature = self.readout(graph, node_feature)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }
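For reference, a minimal instantiation sketch. The hyperparameters below are my reading of the GearNet-Edge-IEConv settings for Fold3D (6 hidden layers of width 512, 7 relation types, 59-dim edge features, 8 angle bins); treat them as assumptions and take the authoritative values from config/downstream/Fold3D/gearnet_edge_ieconv.yaml:

# Assumed hyperparameters; verify against the Fold3D config file.
model = GearNetIEConv(
    input_dim=21, embedding_dim=512, hidden_dims=[512] * 6,
    num_relation=7, edge_input_dim=59, num_angle_bin=8,
    batch_norm=True, concat_hidden=True, short_cut=True,
    readout="sum", dropout=0.2, layer_norm=True, use_ieconv=True,
)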