waityousea closed this 2 months ago
Thank you for your interest in our GraphSL library. The current SLVAE code was revised from the original source code, as noted in the version log. We modified the SLVAE encoder because the revised version performs better than the previous one. Please let me know if you have any questions.
I see that you modified and merged parts of the VAE section of the code.
Is this part of the code no longer needed in the new version?

```python
import scipy.sparse as sp
import torch
import torch.nn as nn


class DiffusionPropagate(nn.Module):
    def __init__(self, prob_matrix, niter):
        super(DiffusionPropagate, self).__init__()
        self.niter = niter
        # Accept either a SciPy sparse matrix or a dense array of edge probabilities
        if sp.isspmatrix(prob_matrix):
            prob_matrix = prob_matrix.toarray()
        self.register_buffer('prob_matrix', torch.FloatTensor(prob_matrix))

    def forward(self, preds, seed_idx):
        device = preds.device
        for i in range(preds.shape[0]):
            prop_pred = preds[i]
            for j in range(self.niter):
                # P2[u, v] = prob_matrix[v, u] * prop_pred[v]: chance that v activates u
                P2 = self.prob_matrix.T * prop_pred.view((1, -1)).expand(self.prob_matrix.shape)
                # P3[u, v]: chance that v does NOT activate u
                P3 = torch.ones(self.prob_matrix.shape).to(device) - P2
                # Node u is active unless every node v fails to activate it
                prop_pred = torch.ones((self.prob_matrix.shape[0], )).to(device) - torch.prod(P3, dim=1)
                # prop_pred[seed_idx[seed_idx[:,0] == i][:, 1]] = 1
            prop_pred = prop_pred.unsqueeze(0)
            # Stack the per-sample results into a (batch, n) tensor
            if i == 0:
                prop_preds = prop_pred
            else:
                prop_preds = torch.cat((prop_preds, prop_pred), 0)
        return prop_preds
```
Yes, that is exactly the point.
The SLVAE code in the new version of GraphSL differs from the old version, and it also differs from the SLVAE repository at https://github.com/triplej0079/SLVAE.