FenTechSolutions / CausalDiscoveryToolbox

Package for causal inference in graphs and in the pairwise settings. Tools for graph structure recovery and dependencies are included.
https://fentechsolutions.github.io/CausalDiscoveryToolbox/html/index.html
MIT License

Cannot reproduce SAM paper results #87

Open lagph opened 3 years ago

lagph commented 3 years ago

I'm unable to reproduce the results reported in the SAM paper, either on synthetic data or on the Sachs data.

This code:

from cdt.data import AcyclicGraphGenerator
import numpy as np
from cdt.causality.graph import SAM
from sklearn.metrics import average_precision_score
import networkx as nx

# Generate a synthetic dataset together with its ground-truth graph
generator = AcyclicGraphGenerator("gp_mix")

data, graph = generator.generate()

sam = SAM(
    train_epochs=3000,
    test_epochs=300,
    dlr=0.001,
    dagpenalization_increase=0.01,
    gpus=1,
    nruns=8,
    njobs=2,
    verbose=True,
    lambda2=0.001,
    lambda1=10,
    nh=20,
    dnh=200,
)

prediction = sam.predict(data)

# Flatten both adjacency matrices and score the predicted edges against the true ones
predicted_adj = nx.adjacency_matrix(prediction).todense()
graph_adj = nx.adjacency_matrix(graph).todense()

print(average_precision_score(np.ravel(graph_adj), np.ravel(predicted_adj)))

gave 0.07 when I ran it. Experimenting a bit, I see values between 0.07 and 0.25 on this data. The paper suggests I should be getting 0.7 here (note that I haven't used cdt.metrics.precision_recall, due to #85).
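
One thing worth ruling out with this kind of manual scoring is node ordering: `nx.adjacency_matrix` orders rows and columns by each graph's own `G.nodes()` unless a `nodelist` is passed, so the two flattened matrices only line up entry by entry if both graphs list their nodes in the same order. The dependency-free sketch below illustrates the idea of scoring against one explicit node list. The function name `average_precision` and its edge-dictionary inputs are hypothetical, chosen for illustration; this is not cdt's or scikit-learn's implementation, and it is not known to be the cause of the gap reported here.

```python
def average_precision(true_edges, scored_edges, nodes):
    """Average precision of scored directed edges against a true edge set.

    true_edges:   set of (source, target) pairs in the ground-truth graph.
    scored_edges: dict mapping (source, target) to a predicted score.
    nodes:        list of node names; one shared ordering for both graphs.
    (All three arguments are illustrative, not part of any library API.)
    """
    # Enumerate every ordered pair once, mirroring a flattened adjacency matrix.
    pairs = [(u, v) for u in nodes for v in nodes]
    labels = [1 if pair in true_edges else 0 for pair in pairs]
    scores = [scored_edges.get(pair, 0.0) for pair in pairs]

    n_pos = sum(labels)
    if n_pos == 0:
        raise ValueError("the true graph has no edges")

    # Rank pairs by decreasing score and sweep one threshold per distinct score.
    ranked = sorted(zip(scores, labels), key=lambda item: -item[0])
    ap, true_pos, prev_recall = 0.0, 0, 0.0
    i = 0
    while i < len(ranked):
        j = i
        # Tied scores fall under the same threshold, so consume them together.
        while j < len(ranked) and ranked[j][0] == ranked[i][0]:
            true_pos += ranked[j][1]
            j += 1
        recall = true_pos / n_pos
        precision = true_pos / j
        ap += (recall - prev_recall) * precision
        prev_recall = recall
        i = j
    return ap


nodes = ["A", "B", "C"]
true_edges = {("A", "B"), ("B", "C")}
scored_edges = {("A", "B"): 0.9, ("C", "B"): 0.8, ("B", "C"): 0.7}
print(average_precision(true_edges, scored_edges, nodes))
```

With networkx graphs, the same alignment can be had by passing the same `nodelist` (for example `list(data.columns)`, assuming the column names match the node names) to both `nx.adjacency_matrix` calls.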

Similarly, with the Sachs data:

from cdt.data import load_dataset
import numpy as np
from cdt.causality.graph import SAM
from cdt.metrics import precision_recall
import networkx as nx

# Load the Sachs dataset together with its ground-truth graph
data, graph = load_dataset('sachs')

sam = SAM(
    train_epochs=3000,
    test_epochs=300,
    dlr=0.001,
    dagpenalization_increase=0.01,
    gpus=1,
    nruns=8,
    njobs=2,
    verbose=True,
    lambda2=0.001,
    lambda1=10,
    nh=20,
    dnh=200,
)

prediction = sam.predict(data)

# Score the predicted graph against the ground truth
average_precision, _ = precision_recall(graph, prediction)
print(average_precision)

gives 0.17, whereas going by the paper I'd expect to see ~0.45. Could you please post a snippet showing how I can reproduce the results reported in the SAM paper using cdt? Thanks.

diviyank commented 3 years ago

Hi,

The model has been updated/fixed; could you try again? Please use 32 runs for the execution and the default parameters (they should be correct).

Best,
Diviyan
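
For readers following along, a minimal sketch of what this suggestion might look like is given below. It assumes that `nruns` (the argument used in the snippets above) is what controls the number of runs, and it leaves every other `SAM` argument at its default. The helper names `run_sam_with_defaults` and `evaluate_on_sachs` are hypothetical, and nothing here has been checked against the paper's numbers.

```python
def run_sam_with_defaults(data, nruns=32):
    """Run SAM with the library defaults, overriding only the number of runs."""
    # Imported lazily so this file can be loaded without cdt installed.
    from cdt.causality.graph import SAM

    sam = SAM(nruns=nruns)  # all other hyperparameters stay at their defaults
    return sam.predict(data)


def evaluate_on_sachs():
    """Score SAM on the Sachs data the same way as the snippet in the report."""
    from cdt.data import load_dataset
    from cdt.metrics import precision_recall

    data, graph = load_dataset("sachs")
    prediction = run_sam_with_defaults(data)
    average_precision, _ = precision_recall(graph, prediction)
    return average_precision
```

Calling `evaluate_on_sachs()` with cdt installed would run the full training, so expect it to take a while with 32 runs.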