Hello! I've been trying to replicate the results on Sachs using the provided hyperparameters, but I'm getting an SHD of ~37-40 instead of the low 10s. Any idea why?
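```python
import networkx as nx
from cdt.data import load_dataset
from castle.metrics import MetricsDAG  # gCastle
from diffan.diffan import DiffAN  # import path depends on how the DiffAN repo is installed

data, graph = load_dataset("sachs")  # pandas DataFrame + networkx DiGraph
data = data.to_numpy()
graph = nx.to_numpy_array(graph)  # ground-truth adjacency matrix
num_nodes = data.shape[1]

model = DiffAN(num_nodes, residue=True)
pred_graph, order = model.fit(data)  # estimated adjacency + topological order
metrics = MetricsDAG(pred_graph, graph).metrics  # MetricsDAG(B_est, B_true)
print(metrics)
```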
This produces
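One possible source of an inflated SHD worth ruling out (an assumption, not a confirmed cause): `nx.to_numpy_array(graph)` orders rows and columns by `graph.nodes()`, which is not guaranteed to match the column order of the data frame, and `data.to_numpy()` discards the column names. With networkx, the aligned version would be `nx.to_numpy_array(graph, nodelist=list(df.columns))`, taken before converting the frame to NumPy. The sketch below uses plain NumPy and hypothetical helper names (`adjacency_from_edges`, `shd`) to show how a mismatched node order alone turns a perfect prediction into a nonzero SHD. Its SHD convention (a missing, extra, or reversed edge each counts as 1) may differ from the one used by `MetricsDAG` or in the paper.

```python
import numpy as np


def adjacency_from_edges(edges, columns):
    """Hypothetical helper: 0/1 adjacency matrix whose row/column order follows `columns`."""
    index = {name: k for k, name in enumerate(columns)}
    adj = np.zeros((len(columns), len(columns)), dtype=int)
    for u, v in edges:
        adj[index[u], index[v]] = 1  # directed edge u -> v
    return adj


def shd(pred, true):
    """Hypothetical helper: structural Hamming distance between binary adjacency matrices.

    A missing edge, an extra edge, or a reversed edge each counts as 1.
    """
    pred = (np.asarray(pred) != 0).astype(int)
    true = (np.asarray(true) != 0).astype(int)
    diff = np.abs(pred - true)
    # Symmetrize so the two mismatched entries of a reversed edge collapse into one pair.
    diff = diff + diff.T
    diff[diff > 1] = 1
    return int(diff.sum() // 2)


# Toy chain a -> b -> c, evaluated with aligned vs. mismatched node orders.
edges = [("a", "b"), ("b", "c")]
aligned = adjacency_from_edges(edges, ["a", "b", "c"])
misaligned = adjacency_from_edges(edges, ["c", "b", "a"])

print(shd(aligned, aligned))     # same graph, same order
print(shd(aligned, misaligned))  # same graph, different order: both edges look reversed
```

Running both comparisons should print 0 and then 2, even though the underlying graph is identical; on an 11-node graph like Sachs the same kind of misalignment could add many spurious errors.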