nnzhan / MTGNN

MIT License
762 stars 216 forks source link

The adj in the forward should be sum(0)? #12

Closed fuleying closed 3 years ago

fuleying commented 3 years ago
class nconv(nn.Module):
    """Propagate node features through an adjacency matrix.

    The node axis of ``x`` (dim 2 of a 4-D tensor, subscript ``v``) is
    contracted with the first axis of ``A``::

        out[n, c, w, l] = sum_v x[n, c, v, l] * A[v, w]

    so node ``w`` gathers from every node ``v`` using column ``w`` of ``A``
    (equivalently, ``A`` is applied transposed).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, A):
        # Column w of A holds the weights node w receives from each source v.
        propagated = torch.einsum('ncvl,vw->ncwl', x, A)
        return propagated.contiguous()

class mixprop(nn.Module):
    """Mix-hop propagation layer.

    Runs ``gdep`` propagation steps over a self-loop-augmented,
    row-normalized adjacency. At each step a fraction ``alpha`` of the
    original input is mixed back in. The input and every step's output are
    then concatenated along dim 1 (channels) and projected with ``self.mlp``.

    Args:
        c_in: number of channels of the input ``x`` (size of dim 1).
        c_out: number of output channels of the final projection.
        gdep: number of propagation steps.
        dropout: stored on the module; NOTE(review): not used by ``forward``.
        alpha: retain ratio; each step computes
            ``h = alpha * x + (1 - alpha) * nconv(h, a)``.
    """

    def __init__(self,c_in,c_out,gdep,dropout,alpha):
        super(mixprop, self).__init__()
        self.nconv = nconv()
        # The projection consumes x plus the gdep step outputs, concatenated
        # on dim 1, hence (gdep + 1) * c_in input channels. `linear` is defined
        # elsewhere in the project (presumably a channel-wise projection --
        # confirm).
        self.mlp = linear((gdep+1)*c_in,c_out)
        self.gdep = gdep
        self.dropout = dropout  # kept for API compatibility; unused in forward
        self.alpha = alpha

    def forward(self, x, adj):
        """Apply mix-hop propagation followed by the output projection.

        Args:
            x: 4-D node-feature tensor; dim 1 is channels and dim 2 is nodes,
                matching the ``'ncvl'`` subscript used by ``nconv``.
            adj: square ``(num_nodes, num_nodes)`` adjacency matrix.

        Returns:
            ``self.mlp`` applied to the ``(gdep + 1) * c_in``-channel
            concatenation of ``x`` and every propagation step.
        """
        # Build the identity on adj's own device and, for floating adjacencies,
        # with adj's dtype. A bare torch.eye() uses the global default dtype
        # on CPU, which silently promotes half/bfloat16 adjacencies to float32
        # and costs an extra host->device copy.
        eye = torch.eye(
            adj.size(0),
            device=adj.device,
            dtype=adj.dtype if adj.is_floating_point() else None,
        )
        adj = adj + eye
        d = adj.sum(1)
        # Row-normalize: every row v of `a` sums to 1.
        # NOTE(review): nconv aggregates along COLUMNS
        # (out[w] = sum_v h[v] * a[v, w]), so the weights a target node w
        # receives do not in general sum to 1. Column normalization
        # (adj.sum(0)) would yield a per-target weighted mean. Row normalization
        # is kept so results match the original implementation.
        a = adj / d.view(-1, 1)
        h = x
        out = [h]
        for _ in range(self.gdep):
            h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a)
            out.append(h)
        ho = torch.cat(out, dim=1)
        return self.mlp(ho)

@nnzhan

ntubiolin commented 3 years ago

Possible duplicate of #9? I ran into this issue before.

fuleying commented 3 years ago

@ntubiolin Sorry about that, I didn't see your issue in the list of open issues. I will close this issue.