Closed: Tianbaojie closed this issue 7 months ago.
# Excerpt from wandb_train.py; relies on module-level `import time` and `import wandb`.
def train(self):
    time_total_start = time.time()
    wandb.log({"epochs": self.epoch})
    for epoch in range(self.epoch):
        t = time.time()
        self.model.train()
        self.optimizer.zero_grad()
        # Forward pass: the encoder receives self.all_edge_index, which
        # (per the report below) also contains the val/test edges.
        adj_recon, mu, logvar = self.model(self.data.x, self.all_edge_index)
        loss = self.criterion.loss_fn(mu=mu,
                                      logvar=logvar,
                                      pos_edge_index=self.data.train_pos_edge_index,
                                      all_edge_index=self.all_edge_index)
        loss.backward()
        self.optimizer.step()
        # Evaluate on the held-out test edges every 10 epochs.
        if (epoch + 1) % 10 == 0:
            self.model.eval()
            roc_auc, ap = self.criterion.single_test(
                mu=mu,
                logvar=logvar,
                test_pos_edge_index=self.data.test_pos_edge_index,
                test_neg_edge_index=self.data.test_neg_edge_index
            )
            self.__metric_log(
                loss=loss.item(),
                roc=roc_auc,
                ap=ap,
                time_lapsed=time.time() - t
            )
            epoch_str = str(epoch + 1).zfill(3) if epoch + 1 < 100 else str(epoch + 1)
            time_elapsed = time.time() - t
            formatted_output = (
                f"TESTING | AT epoch {epoch_str}, "
                f"loss: {loss.item():.4f}, "
                f"ROC_AUC: {roc_auc:.4f}, "
                f"AP: {ap:.4f}, "
                f"Time: {time_elapsed:.4f}"
            )
            print(formatted_output)
    print(f"The total time taken: {(time.time() - time_total_start):.4f}")
Hey, sorry for the very late reply. All the train and test samples live inside the data object. Here I do pass all the edge indices to the model, but at the same time I pass the train positive edge indices, which do the required masking in the loss so that no evaluation edges contribute to it. That said, this is also a general problem in transductive learning: we cannot completely separate the training data from the evaluation data, since both depend on the same graph.
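For reference, the standard PyTorch Geometric VGAE link-prediction recipe restricts both the message passing and the reconstruction loss to the training edges. Below is a minimal sketch of that recipe, assuming the data object has already been split with torch_geometric.utils.train_test_split_edges; the Encoder class and hyperparameters are placeholders, not this repository's code.

import torch
from torch_geometric.nn import VGAE, GCNConv

class Encoder(torch.nn.Module):
    """Two-layer GCN encoder returning (mu, logstd), as VGAE expects."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index).relu()
        return self.conv_mu(h, edge_index), self.conv_logstd(h, edge_index)

def train_epoch(model, optimizer, data):
    # `data` is assumed to come from train_test_split_edges(), so
    # data.train_pos_edge_index contains only training edges.
    model.train()
    optimizer.zero_grad()
    # Message passing sees only the training edges; val/test edges are
    # never fed to the encoder. Node features of all nodes are still
    # used, which is the unavoidable transductive dependency.
    z = model.encode(data.x, data.train_pos_edge_index)
    loss = model.recon_loss(z, data.train_pos_edge_index)
    loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return float(loss)

# Usage sketch:
# model = VGAE(Encoder(data.num_features, 16))
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# for epoch in range(200):
#     loss = train_epoch(model, optimizer, data)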
Closing this issue due to no activity.
Hi, your code is causing label leakage: the test and validation sets' edge information is used during training.
In wandb_train.py, self.all_edge_index includes the test and val edges, and it is passed to the model in the train() method quoted above.
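A quick way to verify that claim is to check whether the edge index fed to the encoder overlaps with the held-out test edges. A small sketch, treating edges as undirected; the function name is hypothetical and the tensor names follow the training loop above.

import torch

def edge_overlap(all_edge_index: torch.Tensor, test_pos_edge_index: torch.Tensor) -> int:
    """Count held-out test edges that also appear in the edge index used
    for message passing; a non-zero count means the encoder can see
    test edges during training."""
    def to_set(edge_index):
        # Normalise each edge to an undirected (min, max) pair.
        src, dst = edge_index
        lo = torch.minimum(src, dst)
        hi = torch.maximum(src, dst)
        return set(zip(lo.tolist(), hi.tolist()))
    return len(to_set(all_edge_index) & to_set(test_pos_edge_index))

# Usage sketch, with the names from the training loop above:
# n = edge_overlap(self.all_edge_index, self.data.test_pos_edge_index)
# print(f"{n} test edges are visible to the encoder")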