vgsatorras / egnn

MIT License

About the implementation of Eq. 4 of the EGNN paper. #5

Open Junyoungpark opened 2 years ago

Junyoungpark commented 2 years ago

Hi,

I have a question about the implementation of Eq. 4 of the EGNN paper. According to Eq. 4, the position update takes into account the interactions among all nodes in the graph. In that sense, the normalizer C = 1/(M-1) makes sense. In the code, however, the aggregation that updates coord runs only over the existing edges. Could you clarify which of the two is the intended behavior of EGNN?

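def coord_model(self, coord, edge_index, coord_diff, edge_feat):
    row, col = edge_index
    # Weight each relative position by the output of coord_mlp for that edge.
    trans = coord_diff * self.coord_mlp(edge_feat)
    # Aggregate per node (indexed by row) over the edges listed in edge_index.
    if self.coords_agg == 'sum':
        agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0))
    elif self.coords_agg == 'mean':
        agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
    else:
        # The original format string had no placeholder, so the % operator itself would fail.
        raise Exception('Wrong coords_agg parameter: %s' % self.coords_agg)
    coord = coord + agg
    return coord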
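
For reference, a sketch of the two forms being compared. The first line is Eq. 4 as described above. The second is what the snippet appears to compute with coords_agg == 'mean', under two assumptions that are not shown in the snippet: coord_diff holds x_i - x_j for each edge (i, j) with i taken from row, and unsorted_segment_mean divides by the number of edges that share the same row index. The symbol N(i) for that set of neighbours is notation introduced here, not taken from the code.

```latex
\begin{align}
x_i^{l+1} &= x_i^{l} + \frac{1}{M-1} \sum_{j \neq i} \left(x_i^{l} - x_j^{l}\right) \phi_x(m_{ij})
  && \text{(Eq. 4, all nodes)} \\
x_i^{l+1} &= x_i^{l} + \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} \left(x_i^{l} - x_j^{l}\right) \phi_x(m_{ij})
  && \text{(edge-based mean)}
\end{align}
```

Under these assumptions the two lines agree whenever N(i) contains every node other than i, that is, whenever edge_index lists every ordered pair of distinct nodes.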

Thank you :) Junyoung

vgsatorras commented 2 years ago

Hi,

As you correctly noticed, this implementation can handle sparse graphs: the edges are provided through the edge_index argument. We can still use it in a fully connected setting (all interactions) by passing fully connected edges in edge_index. This is what we did in the code.

The reason we used a sparse implementation while providing fully connected edges is that it can be easily extended to sparse message passing settings in the future by just replacing the content of the argument edge_index.
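
A minimal sketch of this point in plain Python (lists instead of tensors, so it runs without PyTorch). It is an illustration under stated assumptions, not the repository's code: the helpers fully_connected_edges, radius_edges, toy_weight, coord_update_mean and coord_update_dense are hypothetical names, toy_weight stands in for self.coord_mlp(edge_feat), and the mean is assumed to divide by the number of edges leaving each node.

```python
import math


def fully_connected_edges(num_nodes):
    """Return (row, col) listing every ordered pair (i, j) with i != j."""
    row, col = [], []
    for i in range(num_nodes):
        for j in range(num_nodes):
            if i != j:
                row.append(i)
                col.append(j)
    return row, col


def radius_edges(coord, cutoff):
    """Return (row, col) for ordered pairs closer than `cutoff` (a hypothetical sparse graph)."""
    row, col = [], []
    for i, xi in enumerate(coord):
        for j, xj in enumerate(coord):
            if i != j and math.dist(xi, xj) < cutoff:
                row.append(i)
                col.append(j)
    return row, col


def toy_weight(xi, xj):
    """Stand-in for coord_mlp(edge_feat): any scalar per edge works for this check."""
    return 1.0 / (1.0 + math.dist(xi, xj) ** 2)


def coord_update_mean(coord, row, col):
    """Edge-based update: average (x_i - x_j) * w_ij over the edges whose source is i."""
    dim = len(coord[0])
    sums = [[0.0] * dim for _ in coord]
    counts = [0] * len(coord)
    for i, j in zip(row, col):
        w = toy_weight(coord[i], coord[j])
        counts[i] += 1
        for d in range(dim):
            sums[i][d] += (coord[i][d] - coord[j][d]) * w
    # Nodes without edges keep their position (count clamped to 1, sum is zero).
    return [
        [x + s / max(c, 1) for x, s in zip(coord[i], sums[i])]
        for i, c in enumerate(counts)
    ]


def coord_update_dense(coord):
    """Eq. 4 as described in the thread: C * sum over all j != i, with C = 1 / (M - 1)."""
    m = len(coord)
    out = []
    for i, xi in enumerate(coord):
        acc = [0.0] * len(xi)
        for j, xj in enumerate(coord):
            if j != i:
                w = toy_weight(xi, xj)
                for d in range(len(xi)):
                    acc[d] += (xi[d] - xj[d]) * w
        out.append([x + a / (m - 1) for x, a in zip(xi, acc)])
    return out


coord = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [5.0, 5.0, 5.0]]

full = coord_update_mean(coord, *fully_connected_edges(len(coord)))
dense = coord_update_dense(coord)
sparse = coord_update_mean(coord, *radius_edges(coord, cutoff=3.0))

same_as_eq4 = all(
    math.isclose(a, b, abs_tol=1e-12) for u, v in zip(full, dense) for a, b in zip(u, v)
)
print(same_as_eq4)            # do fully connected edges reproduce the dense sum?
print(sparse[3] == coord[3])  # node 3 has no neighbour within the cutoff
```

With the fully connected edge list, the edge-based mean and the dense sum with C = 1/(M-1) coincide. Swapping in the radius-based edge list changes only which pairs contribute, which is the kind of extension to sparse message passing described above.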

Best, Victor

Junyoungpark commented 2 years ago

Hi Victor,

Thanks for the reply. Your explanation raised a couple of follow-up questions.

The coverage of the aggregation operations in Eq. (4) and Eq. (5)

At first, I thought that Eq. (4) aggregates the interactions of all particles (nodes) on the target node, since the summation runs over all nodes except the target. I thought it reflects the fact that physical forces still act between two particles even when they are arbitrarily far apart. The summation in Eq. (5), on the other hand, is done in a "local" manner.

If I read your comment correctly, both summations are done locally, at least in the implementation. Because the graphs are fully connected, however, local and global aggregation give the same result in practice. Do I understand correctly?

About the coverage of the summations in Eq. (4) and Eq. (7)

This may be quite similar to the question above. I believe the paper takes great care to use precise notation to convey its ideas concisely. Hence, the summations over j ≠ i and over j in N(i) seem intended to express different ideas: for instance, the former aggregates global features, and the latter performs local aggregation. Setting the implementation aside, could you explain this part in more detail?

Thanks for helping me better understand this cool paper :)

Sincerely, Junyoung