I'm unable to reproduce the results reported in the SAM paper, either on synthetic data or on the Sachs data.
This code:
from cdt.data import AcyclicGraphGenerator
import numpy as np
from cdt.causality.graph import SAM
from sklearn.metrics import average_precision_score
import networkx as nx
generator = AcyclicGraphGenerator("gp_mix")
data, graph = generator.generate()
sam = SAM(
    train_epochs=3000,
    test_epochs=300,
    dlr=0.001,
    dagpenalization_increase=0.01,
    gpus=1,
    nruns=8,
    njobs=2,
    verbose=True,
    lambda2=0.001,
    lambda1=10,
    nh=20,
    dnh=200,
)
prediction = sam.predict(data)
# Use the same explicit node order for both matrices so that entries line up.
nodes = list(data.columns)
predicted_adj = nx.adjacency_matrix(prediction, nodelist=nodes).todense()
graph_adj = nx.adjacency_matrix(graph, nodelist=nodes).todense()
print(average_precision_score(np.ravel(graph_adj), np.ravel(predicted_adj)))
gave an average precision of 0.07 when I ran it. Experimenting a bit, I see values between 0.07 and 0.25 on this data, whereas the paper suggests I should be getting 0.7 (note that I haven't used cdt.metrics.precision_recall here, due to #85).
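One thing worth ruling out when scoring with `average_precision_score` is node ordering: `nx.adjacency_matrix` orders rows and columns by each graph's own node order unless `nodelist` is given, so two flattened matrices built separately are not guaranteed to describe the same candidate edges. Below is a minimal sketch that builds both matrices with one explicit node order. The helper name `aligned_average_precision` is hypothetical (not part of cdt). It assumes the predicted edge scores live in the `weight` edge attribute, which is what `nx.adjacency_matrix` reads by default. Excluding the diagonal is a choice made here and differs slightly from the code above.

```python
import networkx as nx
import numpy as np
from sklearn.metrics import average_precision_score


def aligned_average_precision(true_graph, predicted_graph, nodes):
    """Average precision of predicted edge scores against the true DAG.

    Hypothetical helper: both adjacency matrices use the same explicit
    `nodes` order, so entry (i, j) refers to the same candidate edge
    i -> j in both graphs.
    """
    # Binary ground truth: 1.0 where the true graph has edge i -> j, else 0.0.
    y_true = nx.to_numpy_array(true_graph, nodelist=nodes, weight=None)
    # Edge scores from the 'weight' attribute (0.0 where there is no edge).
    y_score = nx.to_numpy_array(predicted_graph, nodelist=nodes, weight="weight")
    # Exclude the diagonal: self-loops are never candidate edges in a DAG.
    off_diagonal = ~np.eye(len(nodes), dtype=bool)
    return average_precision_score(y_true[off_diagonal], y_score[off_diagonal])


# Toy example with made-up scores (not SAM output).
nodes = ["A", "B", "C"]
true_graph = nx.DiGraph()
true_graph.add_nodes_from(nodes)
true_graph.add_edges_from([("A", "B"), ("B", "C")])

predicted_graph = nx.DiGraph()
# Nodes deliberately inserted in a different order than `nodes`.
predicted_graph.add_nodes_from(["C", "B", "A"])
predicted_graph.add_edge("A", "B", weight=0.9)
predicted_graph.add_edge("A", "C", weight=0.6)
predicted_graph.add_edge("B", "C", weight=0.4)

ap = aligned_average_precision(true_graph, predicted_graph, nodes)
print(round(ap, 4))  # 0.8333
```

For the generated data, `nodes` would presumably be `list(data.columns)`, since both the generator and SAM appear to label nodes by column name; that is worth checking on the actual graphs rather than assuming.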
Similarly, with the Sachs data:
from cdt.data import load_dataset
import numpy as np
from cdt.causality.graph import SAM
from cdt.metrics import precision_recall
import networkx as nx
data, graph = load_dataset('sachs')
sam = SAM(
    train_epochs=3000,
    test_epochs=300,
    dlr=0.001,
    dagpenalization_increase=0.01,
    gpus=1,
    nruns=8,
    njobs=2,
    verbose=True,
    lambda2=0.001,
    lambda1=10,
    nh=20,
    dnh=200,
)
prediction = sam.predict(data)
average_precision, _ = precision_recall(graph, prediction)
print(average_precision)
gives an average precision of 0.17, whereas going by the paper I'd expect to see ~0.45. Could you please post a snippet showing how to reproduce the results reported in the SAM paper using cdt? Thanks.
Hi, the model has been updated and fixed; could you try again? Please use 32 runs (nruns=32) and otherwise the default parameters, which should be correct.
Best,
Diviyan