py-why / causal-learn

Causal Discovery in Python. It also includes (conditional) independence tests and score functions.
https://causal-learn.readthedocs.io/en/latest/
MIT License
1.04k stars, 174 forks

TestFCI failed in graph_utils.adj_precision #149

Closed: winstonyu closed this issue 5 months ago

winstonyu commented 8 months ago

I tried to run TestFCI without any modifications but got the errors below. Any ideas what's going on? Thanks.

=======================================================================

/Users/xxx/opt/anaconda3/envs/python311/bin/python /Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pycharm/_jb_unittest_runner.py --target TestFCI.TestFCI
Testing started at 16:17 ...
Launching unittests with arguments python -m unittest TestFCI.TestFCI in /Users/xxx/PycharmProjects/pythonProject/causal-learn/tests

Depth=0, working on node 7: 100%|██████████| 8/8 [00:00<00:00, 522.64it/s]
Depth=0, working on node 4: 100%|██████████| 5/5 [00:00<00:00, 822.09it/s]
Depth=0, working on node 4: 100%|██████████| 5/5 [00:00<00:00, 1059.54it/s]
X3 --> X4 X3 --> X5
Depth=0, working on node 10: 100%|██████████| 11/11 [00:00<00:00, 516.45it/s]
X4 --> X2 X8 --> X2 X8 --> X3 X9 --> X3 X8 --> X4 X9 --> X4 X8 --> X5 X9 --> X5 X8 --> X9
Depth=0, working on node 5: 100%|██████████| 6/6 [00:00<00:00, 694.36it/s]
Depth=0, working on node 36: 100%|██████████| 37/37 [00:00<00:00, 150.56it/s]
X5 --> X2 X5 --> X3 X7 --> X36 X35 --> X9 X35 --> X10 X35 --> X12 X31 --> X16 X33 --> X16 X31 --> X18 X20 --> X21 X32 --> X20 X24 --> X21 X30 --> X26 X29 --> X30 X30 --> X31 X31 --> X32 X32 --> X33 X33 --> X34 X34 --> X35 X35 --> X36
Depth=0, working on node 47: 100%|██████████| 48/48 [00:00<00:00, 113.35it/s]
X4 --> X11 X7 --> X8 X8 --> X14 X8 --> X16 X8 --> X25 X10 --> X12 X10 --> X13 X11 --> X12 X13 --> X14 X14 --> X15 X15 --> X24 X19 --> X36 X20 --> X21 X20 --> X34 X39 --> X20 X21 --> X41 X24 --> X25 X28 --> X33 X32 --> X33 X34 --> X44 X38 --> X42
Depth=0, working on node 19: 100%|██████████| 20/20 [00:00<00:00, 334.04it/s]
X2 --> X8 X16 --> X2 X17 --> X2 X3 --> X9 X17 --> X3 X5 --> X11 X19 --> X5 X6 --> X13 X12 --> X14 X12 --> X16 X12 --> X17 X12 --> X19
Depth=0, working on node 26: 100%|██████████| 27/27 [00:00<00:00, 203.18it/s]
X2 --> X1 X2 --> X14 X3 --> X9 X3 --> X18 X4 --> X10 X4 --> X18 X4 --> X19 X4 --> X27 X8 --> X6 X6 --> X15 X9 --> X7 X7 --> X24 X10 --> X8 X8 --> X21 X8 --> X23 X8 --> X26 X9 --> X12 X9 --> X17 X9 --> X25 X13 --> X10 X13 --> X27 X15 --> X20 X21 --> X20 X24 --> X23 X25 --> X24
Depth=0, working on node 31: 100%|██████████| 32/32 [00:00<00:00, 174.52it/s]
X10 --> X2 X2 --> X11 X2 --> X12 X18 --> X10 X10 --> X19 X11 --> X19 X12 --> X20 X18 --> X26 X19 --> X27 X19 --> X30 X20 --> X28 X20 --> X31 X21 --> X29 X23 --> X31 X24 --> X32
Depth=0, working on node 55: 100%|██████████| 56/56 [00:00<00:00, 89.87it/s]
X8 --> X9 X8 --> X13 X12 --> X13 X15 --> X17 X15 --> X20 X15 --> X25 X17 --> X19 X25 --> X44 X35 --> X36 X40 --> X42 X42 --> X43 X43 --> X44
Depth=0, working on node 69: 100%|██████████| 70/70 [00:00<00:00, 82.04it/s]
X7 --> X8 X8 --> X9 X13 --> X14 X20 --> X13 X14 --> X24 X14 --> X29 X14 --> X32 X14 --> X33 X14 --> X49 X14 --> X53 X14 --> X54 X14 --> X70 X18 --> X19 X19 --> X37 X19 --> X39 X19 --> X59 X19 --> X61 X19 --> X64 X19 --> X65 X19 --> X68 X19 --> X70 X22 --> X30 X39 --> X40
Depth=0, working on node 75: 100%|██████████| 76/76 [00:00<00:00, 78.02it/s]
X3 --> X29 X9 --> X15 X15 --> X17 X17 --> X30 X17 --> X66 X21 --> X30 X25 --> X30 X31 --> X39 X31 --> X43 X31 --> X46 X31 --> X48 X31 --> X54 X31 --> X75 X39 --> X72 X46 --> X49 X48 --> X49 X54 --> X55 X71 --> X72
Depth=0, working on node 222: 100%|██████████| 223/223 [00:09<00:00, 24.69it/s]
X24 --> X26 X26 --> X28 X30 --> X133 X32 --> X36 X35 --> X36 X38 --> X39 X39 --> X54 X41 --> X44 X46 --> X48 X48 --> X50 X50 --> X52 X50 --> X66 X52 --> X54 X52 --> X84 X52 --> X89 X54 --> X59 X59 --> X60 X59 --> X127 X62 --> X144 X66 --> X68 X66 --> X74 X70 --> X72 X72 --> X74 X78 --> X80 X81 --> X82 X84 --> X124 X87 --> X128 X91 --> X92 X91 --> X129 X92 --> X129 X96 --> X101 X105 --> X106 X110 --> X112 X112 --> X114 X114 --> X116 X116 --> X118 X116 --> X131 X118 --> X120 X118 --> X149 X118 --> X151 X120 --> X122 X122 --> X123 X124 --> X125 X131 --> X133 X131 --> X144 X131 --> X147 X131 --> X153 X133 --> X153 X144 --> X153 X151 --> X223 X155 --> X157 X157 --> X158 X158 --> X159 X161 --> X163 X161 --> X164 X163 --> X194 X163 --> X195 X163 --> X196 X163 --> X208 X163 --> X222 X164 --> X166 X164 --> X168 X164 --> X170 X164 --> X186 X164 --> X193 X166 --> X174 X166 --> X175 X168 --> X171 X174 --> X179 X174 --> X181 X179 --> X183 X186 --> X187 X186 --> X206 X187 --> X191 X194 --> X199 X194 --> X200 X199 --> X203 X200 --> X201 X206 --> X223 X219 --> X221 X222 --> X223
Depth=0, working on node 19: 100%|██████████| 20/20 [00:00<00:00, 454.79it/s]
X3 --> X6 X3 --> X10 X4 --> X6 X10 --> X12 X10 --> X16 X12 --> X15 X15 --> X19 X16 --> X20
[(3, 13), (4, 11), (6, 1), (8, 4), (9, 1), (9, 5), (10, 2), (11, 7), (12, 1), (12, 2), (13, 4), (13, 5), (14, 0), (14, 8), (14, 13)]
/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/search/ConstraintBased/FCI.py:736: UserWarning: The number of features is much larger than the sample size!
  warnings.warn("The number of features is much larger than the sample size!")
Depth=0, working on node 9: 100%|██████████| 10/10 [00:00<00:00, 956.95it/s]
X5 --> X8
Graph Nodes: X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15

Graph Edges:

  1. X15 --> X1
  2. X7 --> X2
  3. X10 --> X2
  4. X13 --> X2
  5. X11 --> X3
  6. X13 --> X3
  7. X4 --> X14
  8. X9 --> X5
  9. X5 --> X12
  10. X14 --> X5
  11. X10 --> X6
  12. X14 --> X6
  13. X12 --> X8
  14. X15 --> X9
  15. X15 --> X14

pag:
Graph Nodes: X1;X7;X6;X3;X2;X8;X5;X10;X9;X4

Graph Edges:

  1. X1 o-> X6
  2. X1 o-> X5
  3. X1 o-o X9
  4. X7 o-> X2
  5. X5 o-> X6
  6. X10 o-> X6
  7. X9 o-> X6
  8. X4 o-> X6
  9. X3 o-> X2
  10. X10 o-> X2
  11. X5 --> X8
  12. X9 o-> X5
  13. X4 o-> X5

fci graph:
Graph Nodes: X1;X2;X3;X4;X5;X6;X7;X8;X9;X10

Graph Edges:

  1. X1 o-> X5
  2. X1 o-> X6
  3. X1 o-o X9
  4. X3 o-> X2
  5. X7 o-> X2
  6. X10 o-> X2
  7. X4 o-> X5
  8. X4 o-> X6
  9. X5 o-> X6
  10. X5 --> X8
  11. X9 o-> X5
  12. X9 o-> X6
  13. X10 o-> X6

fci(data, d_separation, 0.05):

Error
Traceback (most recent call last):
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 213, in test_er_graph
    self.run_simulate_data_test(pag, G)
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 154, in run_simulate_data_test
    arrow_precision = graph_utils.arrow_precision(truth, est)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/utils/GraphUtils.py", line 431, in arrow_precision
    confusion = ArrowConfusion(truth, est)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/ArrowConfusion.py", line 43, in __init__
    if truth.get_endpoint(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])) == Endpoint.ARROW:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/Endpoint.py", line 23, in __eq__
    return self.value == other.value
                         ^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'value'

Depth=0, working on node 6: 100%|██████████| 7/7 [00:00<00:00, 1045.92it/s]

fci(data, d_separation, 0.05):

Error
Traceback (most recent call last):
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 147, in test_fritl
    self.run_simulate_data_test(pag, G)
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 154, in run_simulate_data_test
    arrow_precision = graph_utils.arrow_precision(truth, est)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/utils/GraphUtils.py", line 431, in arrow_precision
    confusion = ArrowConfusion(truth, est)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/ArrowConfusion.py", line 43, in __init__
    if truth.get_endpoint(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])) == Endpoint.ARROW:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/Endpoint.py", line 23, in __eq__
    return self.value == other.value
                         ^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'value'

Depth=0, working on node 3: 100%|██████████| 4/4 [00:00<00:00, 2099.78it/s]

fci(data, d_separation, 0.05):

Error
Traceback (most recent call last):
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 81, in test_simple_test
    self.run_simulate_data_test(pag, G)
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 154, in run_simulate_data_test
    arrow_precision = graph_utils.arrow_precision(truth, est)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/utils/GraphUtils.py", line 431, in arrow_precision
    confusion = ArrowConfusion(truth, est)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/ArrowConfusion.py", line 43, in __init__
    if truth.get_endpoint(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])) == Endpoint.ARROW:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/Endpoint.py", line 23, in __eq__
    return self.value == other.value
                         ^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'value'

Depth=0, working on node 6: 100%|██████████| 7/7 [00:00<00:00, 993.74it/s]
X4 --> X1 X2 --> X5

fci(data, d_separation, 0.05):

Error
Traceback (most recent call last):
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 107, in test_simple_test2
    self.run_simulate_data_test(pag, G)
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 154, in run_simulate_data_test
    arrow_precision = graph_utils.arrow_precision(truth, est)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/utils/GraphUtils.py", line 431, in arrow_precision
    confusion = ArrowConfusion(truth, est)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/ArrowConfusion.py", line 43, in __init__
    if truth.get_endpoint(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])) == Endpoint.ARROW:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/Endpoint.py", line 23, in __eq__
    return self.value == other.value
                         ^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'value'

Depth=0, working on node 4: 100%|██████████| 5/5 [00:00<00:00, 1673.84it/s]
X3 --> X4 X3 --> X5

fci(data, d_separation, 0.05):

Error
Traceback (most recent call last):
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 127, in test_simple_test3
    self.run_simulate_data_test(pag, G)
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 154, in run_simulate_data_test
    arrow_precision = graph_utils.arrow_precision(truth, est)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/utils/GraphUtils.py", line 431, in arrow_precision
    confusion = ArrowConfusion(truth, est)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/ArrowConfusion.py", line 43, in __init__
    if truth.get_endpoint(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])) == Endpoint.ARROW:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/Endpoint.py", line 23, in __eq__
    return self.value == other.value
                         ^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'value'

Ran 7 tests in 232.656s

FAILED (errors=5)

Process finished with exit code 1

priamai commented 7 months ago

Can you show me the code exactly from the unit test?

winstonyu commented 7 months ago


import hashlib
import os
import random
import sys
sys.path.append("")
import time
import unittest

from networkx import DiGraph, erdos_renyi_graph, is_directed_acyclic_graph
import numpy as np
import pandas as pd

from causallearn.graph.Dag import Dag
from causallearn.graph.GraphNode import GraphNode
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.utils.cit import chisq, fisherz, kci, d_separation
from causallearn.utils.DAG2PAG import dag2pag
from causallearn.utils.GraphUtils import GraphUtils
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge

######################################### Test Notes ###########################################
# All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/")    #
# are obtained from the code of causal-learn as of commit                                      #
# https://github.com/py-why/causal-learn/commit/5918419 (02-03-2022).                          #
#                                                                                              #
# We are not sure if the results are completely "correct" (reflect ground truth graph) or not. #
# So if you find your tests failed, it means that your modified code is logically inconsistent #
# with the code as of 5918419, but it does not necessarily mean that your code is "wrong".     #
# If you are sure that your modification is "correct" (e.g. fixed some bugs in 5918419),       #
# please report it to us. We will then modify these benchmark results accordingly. Thanks :)   #
######################################### Test Notes ###########################################

BENCHMARK_TXTFILE_TO_MD5 = {
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_asia_fci_chisq_0.05.txt": "65f54932a9d8224459e56c40129e6d8b",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_cancer_fci_chisq_0.05.txt": "0312381641cb3b4818e0c8539f74e802",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_earthquake_fci_chisq_0.05.txt": "a1160b92ce15a700858552f08e43b7de",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_sachs_fci_chisq_0.05.txt": "dced4a202fc32eceb75f53159fc81f3b",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_survey_fci_chisq_0.05.txt": "b1a28eee1e0c6ea8a64ac1624585c3f4",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_alarm_fci_chisq_0.05.txt": "c3bbc2b8aba456a4258dd071a42085bc",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_barley_fci_chisq_0.05.txt": "4a5000e7a582083859ee6aef15073676",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_child_fci_chisq_0.05.txt": "6b7858589e12f04b0f489ba4589a1254",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_insurance_fci_chisq_0.05.txt": "9975942b936aa2b1fc90c09318ca2d08",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_water_fci_chisq_0.05.txt": "48eee804d59526187b7ecd0519556ee5",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_hailfinder_fci_chisq_0.05.txt": "6b9a6b95b6474f8530e85c022f4e749c",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_hepar2_fci_chisq_0.05.txt": "4aae21ff3d9aa2435515ed2ee402294c",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_win95pts_fci_chisq_0.05.txt": "648fdf271e1440c06ca2b31b55ef1f3f",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_andes_fci_chisq_0.05.txt": "04092ae93e54c727579f08bf5dc34c77",
    "tests/TestData/benchmark_returned_results/linear_10_fci_fisherz_0.05.txt": "289c86f9c665bf82bbcc4c9e1dcec3e7"
}
#
INCONSISTENT_RESULT_GRAPH_ERRMSG = "Returned graph is inconsistent with the benchmark. Please check your code with the commit 5918419."
INCONSISTENT_RESULT_GRAPH_WITH_PAG_ERRMSG = "Returned graph is inconsistent with the truth PAG."

# verify files integrity first
for file_path, expected_MD5 in BENCHMARK_TXTFILE_TO_MD5.items():
    with open(file_path, 'rb') as fin:
        assert hashlib.md5(fin.read()).hexdigest() == expected_MD5, \
            f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/5918419/tests/TestData'

def gen_coef():
    return np.random.uniform(1, 3)

class TestFCI(unittest.TestCase):
    def test_simple_test(self):
        data = np.empty(shape=(0, 4))
        true_dag = DiGraph()
        ground_truth_edges = [(0, 1), (0, 2), (1, 3), (2, 3)]
        true_dag.add_edges_from(ground_truth_edges)
        G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)

        ground_truth_nodes = []
        for i in range(4):
            ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
        ground_truth_dag = Dag(ground_truth_nodes)
        for u, v in ground_truth_edges:
            ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])
        pag = dag2pag(ground_truth_dag, [])

        print(f'fci(data, d_separation, 0.05):')
        self.run_simulate_data_test(pag, G)

        nodes = G.get_nodes()
        assert G.is_adjacent_to(nodes[0], nodes[1])

        bk = BackgroundKnowledge().add_forbidden_by_node(nodes[0], nodes[1]).add_forbidden_by_node(nodes[1], nodes[0])
        G_with_background_knowledge, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag,
                                                 background_knowledge=bk)
        assert not G_with_background_knowledge.is_adjacent_to(nodes[0], nodes[1])

    def test_simple_test2(self):
        data = np.empty(shape=(0, 7))
        true_dag = DiGraph()
        ground_truth_edges = [(7, 0), (7, 1), (8, 3), (8, 4), (2, 5), (2, 6), (5, 1), (6, 3), (3, 0), (1, 4)]
        true_dag.add_edges_from(ground_truth_edges)
        G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)
        ground_truth_nodes = []
        for i in range(9):
            ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
        ground_truth_dag = Dag(ground_truth_nodes)
        for u, v in ground_truth_edges:
            ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])

        pag = dag2pag(ground_truth_dag, ground_truth_nodes[7: 9])

        print(f'fci(data, d_separation, 0.05):')
        self.run_simulate_data_test(pag, G)

    def test_simple_test3(self):

        data = np.empty(shape=(0, 5))
        true_dag = DiGraph()
        ground_truth_edges = [(0, 2), (1, 2), (2, 3), (2, 4)]
        true_dag.add_edges_from(ground_truth_edges)
        G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)

        ground_truth_nodes = []
        for i in range(5):
            ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
        ground_truth_dag = Dag(ground_truth_nodes)
        for u, v in ground_truth_edges:
            ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])

        pag = dag2pag(ground_truth_dag, [])

        print(f'fci(data, d_separation, 0.05):')
        self.run_simulate_data_test(pag, G)

    def test_fritl(self):
        data = np.empty(shape=(0, 7))
        true_dag = DiGraph()
        ground_truth_edges = [(7, 0), (7, 5), (8, 0), (8, 6), (9, 3), (9, 4), (9, 6),
                              (0, 1), (0, 2), (1, 2), (2, 4), (5, 6)]
        true_dag.add_edges_from(ground_truth_edges)
        G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)

        ground_truth_nodes = []
        for i in range(10):
            ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
        ground_truth_dag = Dag(ground_truth_nodes)
        for u, v in ground_truth_edges:
            ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])

        pag = dag2pag(ground_truth_dag, ground_truth_nodes[7: 10])

        print(f'fci(data, d_separation, 0.05):')
        self.run_simulate_data_test(pag, G)

    @staticmethod
    def run_simulate_data_test(truth, est):
        graph_utils = GraphUtils()
        adj_precision = graph_utils.adj_precision(truth, est)
        adj_recall = graph_utils.adj_recall(truth, est)
        arrow_precision = graph_utils.arrow_precision(truth, est)
        arrow_recall = graph_utils.arrow_recall(truth, est)

        print(f'adj_precision: {adj_precision}')
        print(f'adj_recall: {adj_recall}')
        print(f'arrow_precision: {arrow_precision}')
        print(f'arrow_recall: {arrow_recall}')
        print()
        assert np.isclose([adj_precision, adj_recall, arrow_precision, arrow_recall], [1.0, 1.0, 1.0, 1.0]).all()

    def test_bnlearn_discrete_datasets(self):
        benchmark_names = [
            "asia", "cancer", "earthquake", "sachs", "survey",
            "alarm", "barley", "child", "insurance", "water",
            "hailfinder", "hepar2", "win95pts",
            "andes"
        ]

        bnlearn_path = 'tests/TestData/bnlearn_discrete_10000/data'
        for bname in benchmark_names:
            data = np.loadtxt(os.path.join(bnlearn_path, f'{bname}.txt'), skiprows=1)
            G, edges = fci(data, chisq, 0.05, verbose=False)
            benchmark_returned_graph = np.loadtxt(
                f'tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_{bname}_fci_chisq_0.05.txt')
            assert np.all(G.graph == benchmark_returned_graph), INCONSISTENT_RESULT_GRAPH_ERRMSG

    def test_continuous_dataset(self):
        data = np.loadtxt('tests/data_linear_10.txt', skiprows=1)
        G, edges = fci(data, fisherz, 0.05, verbose=False)
        benchmark_returned_graph = np.loadtxt(
            f'tests/TestData/benchmark_returned_results/linear_10_fci_fisherz_0.05.txt')
        assert np.all(G.graph == benchmark_returned_graph), INCONSISTENT_RESULT_GRAPH_ERRMSG

    def test_er_graph(self):
        random.seed(42)
        np.random.seed(42)
        p = 0.1
        for _ in range(5):
            data = np.empty(shape=(0, 10))
            true_dag = erdos_renyi_graph(15, p, directed=True)  # The last 5 variables are latent variables
            while not is_directed_acyclic_graph(true_dag):
                true_dag = erdos_renyi_graph(15, p, directed=True)
            ground_truth_edges = list(true_dag.edges)
            print(ground_truth_edges)
            G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)

            ground_truth_nodes = []
            for i in range(15):
                ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
            ground_truth_dag = Dag(ground_truth_nodes)
            for u, v in ground_truth_edges:
                ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])
            print(ground_truth_dag)
            pag = dag2pag(ground_truth_dag, ground_truth_nodes[10:])
            print('pag:')
            print(pag)
            print('fci graph:')
            print(G)
            print(f'fci(data, d_separation, 0.05):')
            self.run_simulate_data_test(pag, G)
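
For context on what `run_simulate_data_test` checks: arrowhead precision is commonly defined as the fraction of arrowheads in the estimated graph that are also arrowheads in the truth. The sketch below computes it over a hypothetical dict-of-endpoints encoding (`endpoints[(a, b)]` is the mark at `b` on the edge between `a` and `b`, and non-adjacent pairs are simply absent). It is not causal-learn's `ArrowConfusion`, whose counting may differ; `get_endpoint` and `arrow_precision` here are illustrative helpers. The point is that a pair with no edge in the truth must be handled as "no arrowhead" rather than compared blindly, which appears to be the situation behind the tracebacks above.

```python
from typing import Dict, Optional, Tuple

# Hypothetical encoding: endpoints[(a, b)] is the mark at node b on the edge
# between a and b ("ARROW", "TAIL" or "CIRCLE"); non-adjacent pairs are absent.
Endpoints = Dict[Tuple[str, str], str]


def get_endpoint(graph: Endpoints, a: str, b: str) -> Optional[str]:
    """Return the mark at b on the a-b edge, or None if a and b are not adjacent."""
    return graph.get((a, b))


def arrow_precision(truth: Endpoints, est: Endpoints) -> float:
    """Share of estimated arrowheads that are also arrowheads in the truth."""
    tp = fp = 0
    for (a, b), mark in est.items():
        if mark != "ARROW":
            continue
        true_mark = get_endpoint(truth, a, b)
        # A missing edge in the truth (None) counts as a false positive
        # instead of being compared as if it were an endpoint.
        if true_mark is not None and true_mark == "ARROW":
            tp += 1
        else:
            fp += 1
    # Convention assumed here: no estimated arrowheads gives 0.0.
    return tp / (tp + fp) if tp + fp else 0.0


truth = {("X1", "X2"): "ARROW", ("X2", "X1"): "CIRCLE"}
est = {
    ("X1", "X2"): "ARROW", ("X2", "X1"): "CIRCLE",  # agrees with the truth
    ("X3", "X2"): "ARROW", ("X2", "X3"): "CIRCLE",  # edge absent from the truth
}
print(arrow_precision(truth, est))  # 0.5
```

The estimated graph has two arrowheads; only the one on `X1 *-> X2` exists in the truth, so the precision is 1 / 2.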

MarkDana commented 7 months ago

Thank you @winstonyu for helping us identify this! The issue comes from the endpoint comparison. A patch is available in https://github.com/py-why/causal-learn/pull/154 and should resolve this issue.
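
The failure can be reproduced in isolation. The failing line compares the result of `truth.get_endpoint(...)` with `Endpoint.ARROW`, and the last frame reads `other.value` on `None`, which suggests `get_endpoint` returned `None` for a node pair with no edge in the truth PAG. Since `None == Endpoint.ARROW` is first tried on `NoneType` (which returns `NotImplemented`), Python falls back to the reflected `Endpoint.ARROW.__eq__(None)`. The sketch below uses two hypothetical stand-in enums (`NaiveEndpoint`, `SafeEndpoint`, with assumed member values); it is not the library's `Endpoint` class and not necessarily what PR #154 changed. It shows one way an `isinstance` guard that returns `NotImplemented` lets the comparison fall back to identity and evaluate to `False`.

```python
from enum import Enum


class NaiveEndpoint(Enum):
    """Hypothetical stand-in mirroring the comparison shown in the traceback."""
    TAIL = -1
    NULL = 0
    ARROW = 1
    CIRCLE = 2

    def __eq__(self, other):
        # Assumes `other` is always an endpoint; fails when it is None.
        return self.value == other.value

    def __hash__(self):
        return hash(self.value)


class SafeEndpoint(Enum):
    """Same members, but the comparison tolerates non-endpoint operands."""
    TAIL = -1
    NULL = 0
    ARROW = 1
    CIRCLE = 2

    def __eq__(self, other):
        if not isinstance(other, SafeEndpoint):
            # Let Python fall back to its default identity comparison.
            return NotImplemented
        return self.value == other.value

    def __hash__(self):
        return hash(self.value)


def naive_check(endpoint):
    """Mimics `get_endpoint(a, b) == Endpoint.ARROW` when the edge is missing."""
    try:
        return endpoint == NaiveEndpoint.ARROW
    except AttributeError as exc:
        return f"AttributeError: {exc}"


print(naive_check(None))  # AttributeError: 'NoneType' object has no attribute 'value'
print(None == SafeEndpoint.ARROW)  # False
print(SafeEndpoint.ARROW == SafeEndpoint.ARROW)  # True
```

An alternative with the same effect is to guard at the call site (`endpoint is not None and endpoint == ...`); the guard inside `__eq__` protects every caller at once.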