HongyangGao / Graph-U-Nets

PyTorch implementation of Graph U-Nets (ICML 2019)
http://proceedings.mlr.press/v97/gao19a/gao19a.pdf
GNU General Public License v3.0

The order of unpooling and skip-connection seems wrong #17

Closed: LingxiaoShawn closed this issue 4 years ago

LingxiaoShawn commented 4 years ago

Hi Hongyang, based on your paper, I think the following code in ops.py has an ordering problem:

    def forward(self, g, h):
        adj_ms = []
        indices_list = []
        down_outs = []
        hs = []
        org_h = h
        # Down path: GCN, store (g, h) for the skip-connection, then pool.
        for i in range(self.l_n):
            h = self.down_gcns[i](g, h)
            adj_ms.append(g)
            down_outs.append(h)
            g, h, idx = self.pools[i](g, h)
            indices_list.append(idx)

        h = self.bottom_gcn(g, h)

        # Up path: stored levels are revisited in reverse order;
        # unpool, then GCN, then add the skip-connection.
        for i in range(self.l_n):
            up_idx = self.l_n - i - 1
            g, idx = adj_ms[up_idx], indices_list[up_idx]
            g, h = self.unpools[i](g, h, down_outs[up_idx], idx)
            h = self.up_gcns[i](g, h)
            h = h.add(down_outs[up_idx])
            hs.append(h)
        h = h.add(org_h)
        hs.append(h)
        return hs
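The call order of this loop structure can be traced without any tensors. The helper below is hypothetical (it is not in the repository) and only records the sequence of operations for `l_n` levels. It assumes nothing beyond what the quoted `forward` shows: the down path stores `g` and `h` before each pool, up step `i` reads back level `l_n - i - 1` (so levels are consumed last-in, first-out), and the skip-add comes after `up_gcns[i]`.

```python
from typing import List


def trace_forward(l_n: int) -> List[str]:
    """Record the call order of the quoted forward() without doing any math.

    Hypothetical helper for illustration only; the strings mirror the
    attribute names used in ops.py (down_gcns, pools, bottom_gcn,
    unpools, up_gcns).
    """
    ops: List[str] = []
    # Down path: GCN, remember (g, h) for this level, then pool.
    for i in range(l_n):
        ops.append(f"down_gcns[{i}]")
        ops.append(f"save level {i}")
        ops.append(f"pools[{i}]")
    ops.append("bottom_gcn")
    # Up path: stored levels are read back in reverse (LIFO) order,
    # and the skip-connection is added after the GCN.
    for i in range(l_n):
        up_idx = l_n - i - 1
        ops.append(f"unpools[{i}] -> level {up_idx}")
        ops.append(f"up_gcns[{i}]")
        ops.append(f"add down_outs[{up_idx}]")
    ops.append("add org_h")
    return ops


if __name__ == "__main__":
    for op in trace_forward(2):
        print(op)
```

With `l_n = 2`, the trace shows `unpools[0]` pairing with level 1 and `unpools[1]` pairing with level 0, which is the reverse pairing both versions in this thread rely on.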

In my understanding, the correct order should be:


        # Proposed up path: unpool, add the skip-connection, then GCN.
        for i in range(self.l_n):
            up_idx = self.l_n - i - 1
            g, idx = adj_ms[up_idx], indices_list[up_idx]

            g, h = self.unpools[up_idx](g, h, down_outs[up_idx], idx)

            h = h.add(down_outs[up_idx])
            hs.append(h)

            h = self.up_gcns[up_idx](g, h)

        h = h.add(org_h)
        hs.append(h)
        return hs

Please let me know whether I'm correct. Thank you again for sharing such well-structured code!

HongyangGao commented 4 years ago

Yes. Actually, you can do it either way. Because graph data are very easy to overfit, applying the addition after the GCN may help avoid this problem.
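Both orders are valid designs, but they are not the same computation. The toy sketch below is pure Python and rests on stated assumptions: `unpool` scatters the pooled rows back to their original node positions and zero-fills the rest (the gUnpool behaviour described in the paper), and `gcn` is a hypothetical stand-in that averages each node's features with its neighbours', with no learned weights or nonlinearity. The flag `skip_before_gcn` switches between the order in ops.py (GCN, then skip-add) and the order proposed in the issue (skip-add, then GCN).

```python
from typing import List

Matrix = List[List[float]]


def unpool(num_nodes: int, h: Matrix, idx: List[int]) -> Matrix:
    """Place pooled rows back at their original node indices; other rows stay zero."""
    feat_dim = len(h[0])
    out = [[0.0] * feat_dim for _ in range(num_nodes)]
    for row, node in zip(h, idx):
        out[node] = list(row)
    return out


def gcn(adj: Matrix, h: Matrix) -> Matrix:
    """Toy propagation: mean over each node's own and its neighbours' features.

    Hypothetical stand-in for a real GCN layer (no weights, no nonlinearity).
    """
    out = []
    for i, row in enumerate(adj):
        nbrs = [j for j, a in enumerate(row) if a != 0 or j == i]
        out.append([sum(h[j][k] for j in nbrs) / len(nbrs) for k in range(len(h[0]))])
    return out


def add(a: Matrix, b: Matrix) -> Matrix:
    """Element-wise sum, playing the role of h.add(down_out)."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def up_step(adj: Matrix, h: Matrix, skip: Matrix, idx: List[int],
            skip_before_gcn: bool) -> Matrix:
    """One decoder level: unpool, then GCN and skip-add in the chosen order."""
    h = unpool(len(adj), h, idx)
    if skip_before_gcn:
        # Order proposed in the issue: skip-add, then GCN.
        return gcn(adj, add(h, skip))
    # Order used in ops.py: GCN, then skip-add.
    return add(gcn(adj, h), skip)


if __name__ == "__main__":
    # Path graph 0 - 1 - 2; pooling kept only node 1.
    adj = [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
    pooled, skip, idx = [[1.0]], [[1.0], [0.0], [0.0]], [1]
    print(up_step(adj, pooled, skip, idx, skip_before_gcn=False))
    print(up_step(adj, pooled, skip, idx, skip_before_gcn=True))
```

On this path graph the ops.py order gives node features of about 1.5, 0.33 and 0.5, while the proposed order gives 1.0, 0.67 and 0.5. Adding the skip after the GCN passes the stored down-path features straight to the output; adding it before lets them be mixed across neighbours. Whether the first choice actually reduces overfitting, as suggested in the reply, is an empirical question this toy does not answer.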