FenTechSolutions / CausalDiscoveryToolbox

Package for causal inference in graphs and in the pairwise settings. Tools for graph structure recovery and dependencies are included.
https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/index.html
MIT License

[BUG] CGNN run() Wrong way to calculate the score #149

Open · Anditty opened this issue 1 year ago

Anditty commented 1 year ago

Code:

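```python
from torch.utils.data import DataLoader
from tqdm import trange

# Excerpt from the body of CGNN's run() method; self, dataset, train_epochs,
# test_epochs, verbose, idx, optim and dataloader_workers come from that method.
dataloader = DataLoader(dataset, batch_size=self.batch_size,
                        shuffle=True, drop_last=True,
                        num_workers=dataloader_workers)

with trange(train_epochs + test_epochs, disable=not verbose) as t:
    for epoch in t:
        for i, data in enumerate(dataloader):
            optim.zero_grad()
            generated_data = self.forward()

            mmd = self.criterion(generated_data, data)
            if not epoch % 200 and i == 0:
                t.set_postfix(idx=idx, loss=mmd.item())
            mmd.backward()
            optim.step()
            if epoch >= test_epochs:  # the guard this issue is about
                self.score.add_(mmd.data)

return self.score.cpu().numpy() / test_epochs
```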

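The guard `if epoch >= test_epochs:` should be `if epoch >= train_epochs:`. The loop runs for `train_epochs + test_epochs` epochs and the returned value divides the accumulated score by `test_epochs`, so only the last `test_epochs` epochs (the ones after training) should be added to `self.score`. With the current guard, the score is accumulated over epochs `test_epochs` to `train_epochs + test_epochs - 1`, i.e. over `train_epochs` epochs: when `train_epochs > test_epochs` this includes part of the training phase, and when `train_epochs < test_epochs` it covers only part of the test phase. Either way the average is wrong. When `train_epochs == test_epochs` both guards select the same epochs, which hides the bug.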
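The effect of the two guards can be checked without PyTorch by replaying only the epoch loop. This is a minimal sketch, not CDT code: `scored_epochs`, `returned_score` and the loss curve are hypothetical, and it assumes one batch per epoch (full-batch training), so that `losses[e]` is the loss of epoch `e`.

```python
def scored_epochs(train_epochs, test_epochs, threshold):
    """Epochs whose loss is added to the score when the guard is `epoch >= threshold`."""
    return [e for e in range(train_epochs + test_epochs) if e >= threshold]


def returned_score(losses, train_epochs, test_epochs, threshold):
    """Mimic accumulating into `score` and returning `score / test_epochs`.

    Assumes one batch per epoch, so `losses[e]` is the loss of epoch `e`.
    """
    total = sum(losses[e] for e in scored_epochs(train_epochs, test_epochs, threshold))
    return total / test_epochs


train_epochs, test_epochs = 6, 2
# Hypothetical loss curve: decreasing while training, then flat at 1.0.
losses = [10.0, 8.0, 6.0, 4.0, 2.0, 1.0, 1.0, 1.0]

print(scored_epochs(train_epochs, test_epochs, threshold=test_epochs))   # [2, 3, 4, 5, 6, 7]
print(scored_epochs(train_epochs, test_epochs, threshold=train_epochs))  # [6, 7]
print(returned_score(losses, train_epochs, test_epochs, threshold=test_epochs))   # 7.5
print(returned_score(losses, train_epochs, test_epochs, threshold=train_epochs))  # 1.0
```

With the current guard, four training epochs leak into the sum and the reported score is 7.5 instead of the test-phase average of 1.0.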

diviyank commented 1 year ago

Hello! Yes indeed, I'll change this ASAP.
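For reference, the corrected accumulation pattern can be sketched independently of CDT. This is not the library's code: `run_and_score`, `batches`, `step` and `fake_step` are hypothetical stand-ins for the `DataLoader` and for the zero_grad/forward/criterion/backward/`optim.step()` sequence. As a design choice, the sketch divides by the number of scored batches rather than by `test_epochs`; with one batch per epoch the two divisors are equal.

```python
from typing import Callable, Sequence


def run_and_score(
    batches: Sequence[object],
    step: Callable[[object], float],
    train_epochs: int,
    test_epochs: int,
) -> float:
    """Train for train_epochs, then average the loss over the following test_epochs.

    `step` stands in for one zero_grad/forward/criterion/backward/step sequence and
    returns the batch loss as a float (the role of `mmd.item()` in the CGNN loop).
    """
    score = 0.0
    n_scored = 0
    for epoch in range(train_epochs + test_epochs):
        for batch in batches:
            loss = step(batch)
            # Guard proposed in the issue: only the last `test_epochs` epochs count.
            if epoch >= train_epochs:
                score += loss
                n_scored += 1
    if n_scored == 0:
        raise ValueError("need test_epochs > 0 and at least one batch")
    # Mean per scored batch; with one batch per epoch this equals score / test_epochs.
    return score / n_scored


# Usage with a hypothetical step whose loss shrinks on every call, floored at 1.0.
calls = {"n": 0}


def fake_step(batch):
    calls["n"] += 1
    return max(1.0, 10.0 - calls["n"])


print(run_and_score(batches=[0, 1], step=fake_step, train_epochs=5, test_epochs=3))  # 1.0
```

Only the 6 batch losses from epochs 5 to 7 are averaged; the 10 training-phase losses are ignored.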