py-why / causal-learn

Causal Discovery in Python. It also includes (conditional) independence tests and score functions.
https://causal-learn.readthedocs.io/en/latest/
MIT License
1.13k stars 186 forks

FCI implementation #159

Open kenneth-lee-ch opened 8 months ago

kenneth-lee-ch commented 8 months ago

Is the current implementation of the FCI algorithm complete according to [1]? The current version does not seem to include rules R5-R10 described in [1].

Reference: [1] Zhang, Jiji. "On the completeness of orientation rules for causal discovery in the presence of latent confounders and selection bias." Artificial Intelligence 172.16-17 (2008): 1873-1896.
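
Completeness in [1] is a property of applying the orientation rules exhaustively: FCI keeps applying them until none changes the graph, which is also why the rule functions posted later in this thread return True or False. The following is a minimal sketch of that fixpoint loop, assuming a hypothetical toy PAG stored as a dict of edge marks (not causal-learn's `GeneralGraph`). `toy_rule8`, `orient_to_fixpoint` and the `marks` layout are illustrative names only; `marks[(x, y)]` is the mark at `y`, the same convention the thread's code relies on for `get_endpoint(X, Y)`.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical toy PAG (not causal-learn's GeneralGraph):
# marks[(x, y)] is the edge mark at y on the edge between x and y,
# one of "tail", "arrow", "circle". Both (x, y) and (y, x) are stored.
Marks = Dict[Tuple[str, str], str]
Rule = Callable[[Marks], bool]


def toy_rule8(marks: Marks) -> bool:
    """R8 on the toy PAG: A --> B --> C or A --o B --> C, plus A o-> C, gives A --> C."""
    changed = False
    nodes = sorted({x for x, _ in marks})
    for a in nodes:
        for c in nodes:
            # A o-> C: circle at A, arrowhead at C.
            if marks.get((c, a)) != "circle" or marks.get((a, c)) != "arrow":
                continue
            for b in nodes:
                if b in (a, c):
                    continue
                # A --> B or A --o B: tail at A, arrowhead or circle at B.
                a_b = marks.get((b, a)) == "tail" and marks.get((a, b)) in ("arrow", "circle")
                # B --> C: tail at B, arrowhead at C.
                b_c = marks.get((c, b)) == "tail" and marks.get((b, c)) == "arrow"
                if a_b and b_c:
                    marks[(c, a)] = "tail"  # orient A o-> C as A --> C
                    changed = True
                    break
    return changed


def orient_to_fixpoint(marks: Marks, rules: List[Rule], max_passes: int = 1000) -> int:
    """Run every rule in turn until a full pass changes nothing; return the number of changing passes."""
    passes = 0
    while passes < max_passes:
        changed = False
        for rule in rules:
            # Call the rule first so that every rule runs on every pass.
            changed = rule(marks) or changed
        if not changed:
            break
        passes += 1
    return passes


# A --> B --> C and A o-> C
marks = {
    ("A", "B"): "arrow", ("B", "A"): "tail",
    ("B", "C"): "arrow", ("C", "B"): "tail",
    ("A", "C"): "arrow", ("C", "A"): "circle",
}
print(orient_to_fixpoint(marks, [toy_rule8]))  # 1
print(marks[("C", "A")])                       # tail
```

In a full implementation the `rules` list would hold R1-R10; the driver itself does not care how many rules there are.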

kunwuz commented 8 months ago

No, it does not include the rules introduced in Jiji's paper yet.

kenneth-lee-ch commented 8 months ago

I think it is useful to let others know that the current FCI implementation is not complete, so I will keep this issue open.

kenneth-lee-ch commented 7 months ago

I implemented rules R8, R9, and R10 with causal-learn's graph API as follows:


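from __future__ import annotations

from itertools import combinations
from queue import Queue
from typing import List, Tuple

from causallearn.graph.Edge import Edge
from causallearn.graph.Endpoint import Endpoint
from causallearn.graph.Graph import Graph
from causallearn.graph.Node import Node
from causallearn.utils.ChoiceGenerator import ChoiceGenerator


def is_possible_parent(graph: Graph, potential_parent_node: Node, child_node: Node) -> bool:
    """
    Test whether a node can possibly be a parent of the given node.
    Require that on the connecting edge
        (a) there is no arrowhead (->) at the tested node, and
        (b) there is no tail (--) at the given child node,
    where circle edge-marks (o) are allowed.
    :param graph: the partial ancestral graph (PAG) being oriented
    :param potential_parent_node: the node that is being tested
    :param child_node: the node that serves as the child
    :return: True if potential_parent_node can be a parent of child_node; otherwise False
    """
    if graph.node_map[potential_parent_node] == graph.node_map[child_node]:
        return False
    if not graph.is_adjacent_to(potential_parent_node, child_node):
        return False

    # get_endpoint(X, Y) is the mark at Y on the X *-* Y edge.
    if graph.get_endpoint(child_node, potential_parent_node) == Endpoint.ARROW or \
            graph.get_endpoint(potential_parent_node, child_node) == Endpoint.TAIL:
        return False
    return True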

def find_possible_children(graph: Graph, parent_node: Node, en_nodes=None) -> set:
    """
    Return the nodes in en_nodes (default: every other node in the graph)
    that parent_node can possibly be a parent of.
    """
    if en_nodes is None:
        nodes = graph.get_nodes()
        en_nodes = [node for node in nodes if graph.node_map[node] != graph.node_map[parent_node]]

    potential_child_nodes = set()
    for potential_node in en_nodes:
        if is_possible_parent(graph, potential_parent_node=parent_node, child_node=potential_node):
            potential_child_nodes.add(potential_node)

    return potential_child_nodes

def find_possible_parents(graph: Graph, child_node: Node, en_nodes=None) -> list:
    """
    Return the nodes in en_nodes (default: every other node in the graph)
    that can possibly be a parent of child_node.
    """
    if en_nodes is None:
        nodes = graph.get_nodes()
        en_nodes = [node for node in nodes if graph.node_map[node] != graph.node_map[child_node]]

    possible_parents = [parent_node for parent_node in en_nodes if is_possible_parent(graph, parent_node, child_node)]

    return possible_parents

def existsSemidirectedPath(node_from: Node, node_to: Node, G: Graph, excluded=None) -> bool:
    """
    Breadth-first search for a semidirected path from node_from to node_to, i.e. a path on which
    no edge has an arrowhead at the node it is traversed from.
    Nodes in `excluded` are never entered; rule9 and rule10 use this so that the path
    <A, B, ..., C> cannot route back through A. The search does not check that the path is uncovered.
    """
    Q = Queue()
    V = set(excluded) if excluded is not None else set()
    Q.put(node_from)
    V.add(node_from)
    while not Q.empty():
        node_t = Q.get_nowait()
        if node_t == node_to:
            return True

        for node_u in G.get_adjacent_nodes(node_t):
            edge = G.get_edge(node_t, node_u)
            node_c = traverseSemiDirected(node_t, edge)

            # Skip edges with an arrowhead at node_t, and nodes already visited or excluded.
            if node_c is None or node_c in V:
                continue

            V.add(node_c)
            Q.put(node_c)

    return False

def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
    """
    Return the node at the other end of `edge` if the mark at `node` is a tail or a circle
    (so the edge can be followed semidirectedly away from `node`); otherwise return None.
    """
    if node == edge.get_node1():
        if edge.get_endpoint1() in (Endpoint.TAIL, Endpoint.CIRCLE):
            return edge.get_node2()
    elif node == edge.get_node2():
        if edge.get_endpoint2() in (Endpoint.TAIL, Endpoint.CIRCLE):
            return edge.get_node1()
    return None

def rule8(graph: Graph, nodes: List[Node], ambiguous_triple=None) -> bool:
    """
    [R8] If A ---> B ---> C or A ---o B ---> C, and A o--> C, then orient A o--> C as A ---> C.

    :param graph: the PAG being oriented (modified in place)
    :param nodes: ignored; all nodes of the graph are scanned
    :param ambiguous_triple: list of (A, B, C) triples to skip; defaults to an empty list
    :return: True if the graph was modified; otherwise False
    """
    if ambiguous_triple is None:
        ambiguous_triple = []
    changeFlag = False
    for node_B in graph.get_nodes():
        adj = graph.get_adjacent_nodes(node_B)
        if len(adj) < 2:
            continue

        # Try every ordered pair (A, C): the pattern is not symmetric in A and C.
        for node_A in adj:
            for node_C in adj:
                if node_A == node_C:
                    continue
                if (node_A, node_B, node_C) in ambiguous_triple or (node_C, node_B, node_A) in ambiguous_triple:
                    continue

                # get_endpoint(X, Y) is the mark at Y on the X *-* Y edge.
                # A ---> B or A ---o B: tail at A, arrowhead or circle at B.
                a_into_b = graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
                    graph.get_endpoint(node_A, node_B) in (Endpoint.ARROW, Endpoint.CIRCLE)
                # B ---> C: tail at B, arrowhead at C.
                b_to_c = graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
                    graph.get_endpoint(node_B, node_C) == Endpoint.ARROW
                # A o--> C: circle at A, arrowhead at C.
                a_circ_c = graph.is_adjacent_to(node_A, node_C) and \
                    graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE and \
                    graph.get_endpoint(node_A, node_C) == Endpoint.ARROW

                if a_into_b and b_to_c and a_circ_c:
                    graph.remove_edge(graph.get_edge(node_A, node_C))
                    graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
                    changeFlag = True

    return changeFlag

def rule9(graph: Graph, nodes: List[Node]) -> bool:
    """
    [R9] If A o--> C and there is a possibly directed uncovered path <A, B, ..., D, C> such that
    B and C are not adjacent, then orient A o--> C as A ---> C.

    Note: existsSemidirectedPath does not verify that the path is uncovered.

    :return: True if the graph was modified; otherwise False
    """
    changeFlag = False
    for node_C in graph.get_nodes():
        intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
        for node_A in intoCArrows:
            # We want A o--> C.
            if graph.get_endpoint(node_C, node_A) != Endpoint.CIRCLE:
                continue

            # Candidates for B: possible children of A other than C.
            c_node_idx = graph.node_map[node_C]
            a_adj_nodes = [node for node in graph.get_adjacent_nodes(node_A) if graph.node_map[node] != c_node_idx]
            possible_children = find_possible_children(graph, node_A, a_adj_nodes)
            for node_B in possible_children:
                if graph.is_adjacent_to(node_B, node_C):
                    continue
                # A must not reappear on the path <A, B, ..., C>.
                if existsSemidirectedPath(node_from=node_B, node_to=node_C, G=graph, excluded={node_A}):
                    graph.remove_edge(graph.get_edge(node_A, node_C))
                    graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
                    changeFlag = True
                    # A o--> C is now A ---> C; move on to the next A.
                    break
    return changeFlag

def rule10(graph: Graph) -> bool:
    """
    [R10] Suppose A o--> C, B ---> C <--- D, and there are two possibly directed uncovered paths
    <A, E, ..., B> and <A, F, ..., D> such that E and F are not adjacent; either path may be a
    single edge (A o--> B or A o--> D, i.e. E = B or F = D). Then orient A o--> C as A ---> C.

    Note: as in rule9, the path search does not verify that the paths are uncovered.

    :return: True if the graph was modified; otherwise False
    """
    changeFlag = False
    for node_C in graph.get_nodes():
        intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
        if len(intoCArrows) < 2:
            continue
        # All A with A o--> C.
        Anodes = [node_A for node_A in intoCArrows if graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE]

        for node_A in Anodes:
            en_nodes = [i for i in graph.get_adjacent_nodes(node_A) if i != node_C]
            A_possible_children = list(find_possible_children(graph, parent_node=node_A, en_nodes=en_nodes))
            if len(A_possible_children) < 2:
                continue

            oriented = False
            gen = ChoiceGenerator(len(intoCArrows), 2)
            choice = gen.next()
            while choice is not None and not oriented:
                node_B = intoCArrows[choice[0]]
                node_D = intoCArrows[choice[1]]
                choice = gen.next()

                # We want B ---> C <--- D (tails at B and D).
                if graph.get_endpoint(node_C, node_B) != Endpoint.TAIL or \
                        graph.get_endpoint(node_C, node_D) != Endpoint.TAIL:
                    continue

                # Try every ordered pair (E, F): E must lead to B and F to D.
                for child_one in A_possible_children:
                    for child_two in A_possible_children:
                        if child_one == child_two or graph.is_adjacent_to(child_one, child_two):
                            continue
                        if existsSemidirectedPath(node_from=child_one, node_to=node_B, G=graph, excluded={node_A}) and \
                                existsSemidirectedPath(node_from=child_two, node_to=node_D, G=graph, excluded={node_A}):
                            graph.remove_edge(graph.get_edge(node_A, node_C))
                            graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
                            changeFlag = True
                            oriented = True
                            break
                    if oriented:
                        # A o--> C is now A ---> C; move on to the next A.
                        break

    return changeFlag

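R9 and R10 require the paths to be *uncovered* and *possibly directed*, while `existsSemidirectedPath` only checks that no edge has an arrowhead at the node it is traversed from; it checks neither uncoveredness nor the mark at the far end of each edge. As a minimal sketch of the missing check, assuming a hypothetical toy PAG stored as a dict of edge marks (not causal-learn's API; all function names here are illustrative), a candidate path can be tested directly:

```python
from typing import Dict, List, Tuple

# Hypothetical toy PAG (not causal-learn's GeneralGraph):
# marks[(x, y)] is the edge mark at y on the x *-* y edge ("tail", "arrow" or "circle").
Marks = Dict[Tuple[str, str], str]


def adjacent(marks: Marks, x: str, y: str) -> bool:
    return (x, y) in marks


def is_possibly_directed(marks: Marks, path: List[str]) -> bool:
    """Every edge <u, v> on the path is neither into u (arrowhead at u) nor out of v (tail at v)."""
    for u, v in zip(path, path[1:]):
        if not adjacent(marks, u, v):
            return False
        if marks[(v, u)] == "arrow" or marks[(u, v)] == "tail":
            return False
    return True


def is_uncovered(marks: Marks, path: List[str]) -> bool:
    """Every consecutive triple <x, y, z> on the path is unshielded: x and z are not adjacent."""
    return all(not adjacent(marks, path[i - 1], path[i + 1]) for i in range(1, len(path) - 1))


def is_uncovered_pd_path(marks: Marks, path: List[str]) -> bool:
    # A path may not visit a node twice.
    return len(set(path)) == len(path) and is_possibly_directed(marks, path) and is_uncovered(marks, path)


def r9_applies(marks: Marks, path: List[str]) -> bool:
    """R9 premise for a concrete path <A, B, ..., C>: A o-> C, B and C not adjacent, path uncovered and p.d."""
    if len(path) < 3:
        return False
    a, b, c = path[0], path[1], path[-1]
    return (marks.get((c, a)) == "circle" and marks.get((a, c)) == "arrow"
            and not adjacent(marks, b, c)
            and is_uncovered_pd_path(marks, path))


# A o-> B --> D --> C with A o-> C; A, D and B, C are not adjacent.
marks = {
    ("A", "B"): "arrow", ("B", "A"): "circle",
    ("B", "D"): "arrow", ("D", "B"): "tail",
    ("D", "C"): "arrow", ("C", "D"): "tail",
    ("A", "C"): "arrow", ("C", "A"): "circle",
}
print(r9_applies(marks, ["A", "B", "D", "C"]))  # True
```

Enumerating candidate paths is left to the search; a search that only extends a path when the new triple is unshielded would enforce uncoveredness as it goes.
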
kunwuz commented 7 months ago

Thanks so much! Would you like to incorporate them into a PR, so that these functions can be merged into causal-learn after tests? Of course, these functions are already super helpful!

kenneth-lee-ch commented 7 months ago

Sure thing.

tingyushi commented 6 months ago

Hello @kenneth-lee-ch @kunwuz,

Does the current implementation include rules 5, 6, and 7? Also, could somebody please merge the corresponding pull request?

Thanks!!

kunwuz commented 5 months ago

Hi, rules 5, 6, and 7 are not included in the implementation yet. The pull request by @kenneth-lee-ch has just been merged. It would be great if you are interested in implementing rules 5, 6, and 7 :)
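
For anyone picking this up, R6 and R7 in [1] are local rules that turn a circle into a tail: R6 says that if A --- B o--* C, then B o--* C becomes B ---* C; R7 says that if A --o B o--* C and A, C are not adjacent, then B o--* C becomes B ---* C. The following is a minimal sketch on a hypothetical toy PAG stored as a dict of edge marks, not the causal-learn API; `toy_rule6` and `toy_rule7` are illustrative names. R5 is omitted because it needs a search over uncovered circle paths.

```python
from typing import Dict, Tuple

# Hypothetical toy PAG (not causal-learn's GeneralGraph):
# marks[(x, y)] is the edge mark at y on the x *-* y edge ("tail", "arrow" or "circle").
Marks = Dict[Tuple[str, str], str]


def toy_rule6(marks: Marks) -> bool:
    """[R6] If A --- B o--* C, orient B o--* C as B ---* C (A and C may be adjacent)."""
    changed = False
    nodes = sorted({x for x, _ in marks})
    for b in nodes:
        for a in nodes:
            # A --- B: tails at both ends.
            if marks.get((a, b)) != "tail" or marks.get((b, a)) != "tail":
                continue
            for c in nodes:
                if c not in (a, b) and marks.get((c, b)) == "circle":
                    marks[(c, b)] = "tail"  # the circle at B becomes a tail
                    changed = True
    return changed


def toy_rule7(marks: Marks) -> bool:
    """[R7] If A --o B o--* C and A, C are not adjacent, orient B o--* C as B ---* C."""
    changed = False
    nodes = sorted({x for x, _ in marks})
    for b in nodes:
        for a in nodes:
            # A --o B: tail at A, circle at B.
            if marks.get((b, a)) != "tail" or marks.get((a, b)) != "circle":
                continue
            for c in nodes:
                # C must differ from A and B and must not be adjacent to A.
                if c in (a, b) or (a, c) in marks:
                    continue
                if marks.get((c, b)) == "circle":
                    marks[(c, b)] = "tail"
                    changed = True
    return changed


# R6: X --- Y o--> Z becomes X --- Y ---> Z
m6 = {("X", "Y"): "tail", ("Y", "X"): "tail", ("Y", "Z"): "arrow", ("Z", "Y"): "circle"}
print(toy_rule6(m6), m6[("Z", "Y")])  # True tail

# R7: P --o Q o--o R with P, R not adjacent becomes P --o Q --o R
m7 = {("P", "Q"): "circle", ("Q", "P"): "tail", ("Q", "R"): "circle", ("R", "Q"): "circle"}
print(toy_rule7(m7), m7[("R", "Q")])  # True tail
```

Like the rules above, each function reports whether it changed the graph, so it can be dropped into the same apply-until-no-change loop.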