xianggebenben / GraphSL

Graph Source Localization Library
MIT License

Some questions about the SLVAE algorithm code #19

Closed: waityousea closed this issue 2 months ago.

waityousea commented 2 months ago

The SLVAE model code in the new version of GraphSL differs from the old version. It also differs from the code at https://github.com/triplej0079/SLVAE.

xianggebenben commented 2 months ago

Thank you for your interest in our GraphSL library. The current SLVAE code is a revision of the original source code, as noted in the version log. We modified the SLVAE encoder because the modified version performs better than the previous one. Please let me know if you have any further questions.

waityousea commented 2 months ago

I understand that you modified and merged parts of the VAE section of the code.

Is this part of the code no longer needed in the new version?

```python
import scipy.sparse as sp
import torch
import torch.nn as nn


class DiffusionPropagate(nn.Module):
    def __init__(self, prob_matrix, niter):
        super(DiffusionPropagate, self).__init__()

        self.niter = niter

        # Densify a SciPy sparse matrix before storing it as a (non-trainable) buffer
        if sp.isspmatrix(prob_matrix):
            prob_matrix = prob_matrix.toarray()

        self.register_buffer('prob_matrix', torch.FloatTensor(prob_matrix))

    def forward(self, preds, seed_idx):
        # import ipdb; ipdb.set_trace()
        # prop_preds = torch.ones((preds.shape[0], preds.shape[1])).to(device)
        device = preds.device

        for i in range(preds.shape[0]):
            prop_pred = preds[i]
            for j in range(self.niter):
                # P2[u, v] = prob_matrix[v, u] * prop_pred[v]
                P2 = self.prob_matrix.T * prop_pred.view((1, -1)).expand(self.prob_matrix.shape)
                P3 = torch.ones(self.prob_matrix.shape).to(device) - P2
                # Probability that node u is activated by at least one neighbor v
                prop_pred = torch.ones((self.prob_matrix.shape[0], )).to(device) - torch.prod(P3, dim=1)
                # prop_pred[seed_idx[seed_idx[:,0] == i][:, 1]] = 1
                prop_pred = prop_pred.unsqueeze(0)
            # Stack the per-sample results into a (batch, num_nodes) tensor
            if i == 0:
                prop_preds = prop_pred
            else:
                prop_preds = torch.cat((prop_preds, prop_pred), 0)

        return prop_preds
```
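For readers comparing the old and new versions, the nested loop in the quoted `forward` applies an independent-cascade-style update to every node `u` on each iteration: `p[u] <- 1 - prod_v(1 - P[v, u] * p[v])`, i.e. the probability that at least one neighbor `v` activates `u`, assuming edges fire independently. The sketch below computes the same update for the whole batch at once. `propagate_ic` is a hypothetical helper name, not part of GraphSL or the SLVAE repository, and it assumes `prob_matrix` is a dense float tensor where `prob_matrix[v, u]` is the probability that `v` activates `u`.

```python
import torch


def propagate_ic(prob_matrix: torch.Tensor, preds: torch.Tensor, niter: int) -> torch.Tensor:
    """Hypothetical batched sketch of the per-sample loop in DiffusionPropagate.forward.

    prob_matrix: (N, N) dense tensor; prob_matrix[v, u] = P(v activates u) (assumed convention).
    preds:       (B, N) activation probabilities for B samples.
    Returns a (B, N) tensor after `niter` propagation steps.
    """
    prop = preds
    for _ in range(niter):
        # p2[b, u, v] = prob_matrix[v, u] * prop[b, v], built by broadcasting to (B, N, N)
        p2 = prob_matrix.T.unsqueeze(0) * prop.unsqueeze(1)
        # Probability that u is activated by at least one neighbor v (independence assumed)
        prop = 1.0 - torch.prod(1.0 - p2, dim=2)
    return prop


# Usage: a three-node chain 0 -> 1 -> 2 where every edge fires with probability 1.
P = torch.zeros(3, 3)
P[0, 1] = 1.0
P[1, 2] = 1.0
seed = torch.tensor([[1.0, 0.0, 0.0]])
print(propagate_ic(P, seed, 1).tolist())  # → [[0.0, 1.0, 0.0]]
print(propagate_ic(P, seed, 2).tolist())  # → [[0.0, 0.0, 1.0]]
```

On the chain, the activation front moves one hop per iteration, and the seed itself does not stay active, because the seed-clamping line is commented out in the quoted code and omitted here as well. Broadcasting replaces the per-sample Python loop with a (B, N, N) intermediate tensor, so memory grows quadratically with the number of nodes; for large graphs a sparse formulation may be preferable.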
xianggebenben commented 2 months ago

Yes, that is exactly right.