salesforce / causalai

Salesforce CausalAI Library: A Fast and Scalable framework for Causal Analysis of Time Series and Tabular Data
BSD 3-Clause "New" or "Revised" License

converting results of PC discovery to graph chart #16

Closed priamai closed 8 months ago

priamai commented 8 months ago

Hi there, I am following the example code and I think a piece of functionality is missing that would make the library easier to use. Let me explain. At some point in the code we do this:


import time

import matplotlib.pyplot as plt

from causalai.models.time_series.pc import PCSingle, PC
from causalai.models.common.CI_tests.partial_correlation import PartialCorrelation
from causalai.models.common.CI_tests.kci import KCI
from causalai.data.data_generator import DataGenerator, GenerateRandomTimeseriesSEM
from causalai.models.common.CI_tests.discrete_ci_tests import DiscreteCI_tests

# also importing data object, data transform object, and prior knowledge object, and the graph plotting function
from causalai.data.time_series import TimeSeriesData
from causalai.data.transforms.time_series import StandardizeTransform
from causalai.models.common.prior_knowledge import PriorKnowledge
from causalai.misc.misc import plot_graph, get_precision_recall

tic = time.time()
result = pc_single.run(target_var=target_var, pvalue_thres=pvalue_thres, max_lag=max_lag, max_condition_set_size=None)

toc = time.time()
print(f'Time taken: {toc-tic:.2f}s\n')

for key in result.keys():
    print(key, '\n', result[key])
    print()

The result is a dict with parents, p_dict, and value keys. I now want to plot the DAG as in the other example code:

plot_graph(result, node_size=1000)

for i, n in enumerate(var_names):
    plt.plot(data_trans[-100:,i], label=n)
plt.legend()
plt.show()

This fails, of course, because result does not follow the schema of graph_gt from the example:

data_array, var_names, graph_gt = DataGenerator(sem, T=T, seed=0)

So my question: is there a utility function that converts the result dict into a graph_gt-like dict?

Thanks.

priamai commented 8 months ago

Never mind, I found the solution, and it is simple:

# print estimated causal graph
graph_est = {n: [] for n in result.keys()}
for key in result.keys():
    parents = result[key]['parents']
    graph_est[key].extend(parents)
    print(f'{key}: {parents}')

and then

plot_graph(graph_est, node_size=1000)

for i, n in enumerate(var_names):
    plt.plot(data_trans[-100:,i], label=n)
plt.legend()
plt.show()
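
For anyone landing here later, the conversion can be wrapped in a small helper. This is only a sketch: result_to_graph is a hypothetical name, not part of causalai, and it assumes that result maps each variable name to a dict containing a 'parents' list, as the loop in this thread implies. The toy input is made up for illustration and is not real PC output.

```python
from typing import Any, Dict, Hashable, List


def result_to_graph(result: Dict[Hashable, Dict[str, Any]]) -> Dict[Hashable, List[Any]]:
    """Collect each variable's 'parents' entry into a {child: [parents]} dict.

    Hypothetical helper (not part of causalai). Assumes every value in
    `result` is a dict with a 'parents' key.
    """
    graph = {}
    for child, info in result.items():
        # Copy the list so later edits to the graph do not mutate `result`.
        graph[child] = list(info['parents'])
    return graph


# Toy input, made up for illustration only.
toy_result = {
    'a': {'parents': [('b', -1)]},
    'b': {'parents': []},
}
graph_est = result_to_graph(toy_result)
print(graph_est)
```

The returned dict can then be passed on the same way as graph_est above, for example plot_graph(graph_est, node_size=1000).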