vgsatorras / egnn

MIT License

Clean code for interpretation #14

(Open) LuisHDBueno opened this issue 4 months ago

LuisHDBueno commented 4 months ago

I was studying your paper and, when looking into the implementation, came upon the following in .\models\egnn_clean\egnn_clean.py:

    self.edge_mlp = nn.Sequential(
        nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
        act_fn,
        nn.Linear(hidden_nf, hidden_nf),
        act_fn)

    self.node_mlp = nn.Sequential(
        nn.Linear(hidden_nf + input_nf, hidden_nf),
        act_fn,
        nn.Linear(hidden_nf, output_nf))

    layer = nn.Linear(hidden_nf, 1, bias=False)
    torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

    coord_mlp = []
    coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
    coord_mlp.append(act_fn)
    coord_mlp.append(layer)
    if self.tanh:
        coord_mlp.append(nn.Tanh())
    self.coord_mlp = nn.Sequential(*coord_mlp)

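I was somewhat confused as to why self.coord_mlp was built differently from self.edge_mlp and self.node_mlp. It is assembled by appending layers to a list and unpacking that list into nn.Sequential, rather than by passing the layers to nn.Sequential directly. It took me a while to convince myself that the two styles produce equivalent modules.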
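As a quick sanity check, here is a minimal sketch, assuming PyTorch is installed. `hidden_nf = 8` and `act_fn = nn.SiLU()` are illustrative stand-ins, not values taken from the repository. It builds the coordinate MLP (with tanh enabled) in both the list-append style and the direct style, copies the weights across, and compares outputs. `nn.Sequential(*modules)` just unpacks the list into positional arguments, so the two forms yield the same submodules in the same order.

    import torch
    from torch import nn

    torch.manual_seed(0)
    hidden_nf = 8          # hypothetical width, for illustration only
    act_fn = nn.SiLU()     # assumed activation; the layer receives act_fn as a parameter

    # List style, as in egnn_clean.py (tanh branch taken).
    modules = [nn.Linear(hidden_nf, hidden_nf), act_fn, nn.Linear(hidden_nf, 1, bias=False)]
    modules.append(nn.Tanh())
    list_style = nn.Sequential(*modules)

    # Direct style, as used for edge_mlp and node_mlp.
    direct_style = nn.Sequential(
        nn.Linear(hidden_nf, hidden_nf),
        act_fn,
        nn.Linear(hidden_nf, 1, bias=False),
        nn.Tanh())

    # Same module types in the same order, so the state_dict keys match too.
    same_structure = [type(m) for m in list_style] == [type(m) for m in direct_style]
    direct_style.load_state_dict(list_style.state_dict())

    x = torch.randn(5, hidden_nf)
    same_output = torch.equal(list_style(x), direct_style(x))
    print(same_structure, same_output)  # True True

The list form only earns its keep because of the conditional `nn.Tanh()`. Otherwise the two constructions are interchangeable.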

For clarity, I would therefore recommend the following refactor:

    self.coord_mlp = nn.Sequential(
        nn.Linear(hidden_nf, hidden_nf),
        act_fn,
        layer,
        nn.Tanh() if self.tanh else nn.Identity())
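One thing worth checking is whether the refactor stays compatible with checkpoints saved from the original layout. Below is a minimal sketch, assuming PyTorch. The helper names `original_coord_mlp` and `refactored_coord_mlp` are hypothetical and not part of the repository. It builds both versions for `tanh` in `(False, True)`, loads the original's `state_dict` into the refactored module with the default strict key matching, and compares outputs.

    import torch
    from torch import nn


    def original_coord_mlp(hidden_nf, act_fn, tanh):
        # Hypothetical helper mirroring the list-append construction quoted above.
        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
        coord_mlp = []
        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        coord_mlp.append(act_fn)
        coord_mlp.append(layer)
        if tanh:
            coord_mlp.append(nn.Tanh())
        return nn.Sequential(*coord_mlp)


    def refactored_coord_mlp(hidden_nf, act_fn, tanh):
        # Hypothetical helper for the proposed version; nn.Identity replaces the missing Tanh.
        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
        return nn.Sequential(
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            layer,
            nn.Tanh() if tanh else nn.Identity())


    torch.manual_seed(0)
    results = {}
    for tanh in (False, True):
        old = original_coord_mlp(16, nn.SiLU(), tanh)
        new = refactored_coord_mlp(16, nn.SiLU(), tanh)
        # nn.Identity has no parameters, so the keys ("0.weight", "0.bias", "2.weight")
        # are identical and the strict load succeeds.
        new.load_state_dict(old.state_dict())
        x = torch.randn(4, 16)
        results[tanh] = (len(old), len(new), torch.equal(old(x), new(x)))

    print(results)  # {False: (3, 4, True), True: (4, 4, True)}

The only visible difference is that, with `tanh=False`, the refactored module carries an extra `nn.Identity` at index 3. Since it holds no parameters, outputs and saved checkpoints are unaffected.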