py-why / causal-learn

Causal Discovery in Python. It also includes (conditional) independence tests and score functions.
https://causal-learn.readthedocs.io/en/latest/
MIT License

GeneralGraph.subgraph bug #117

Closed JanMarcoRuizdeVargas closed 1 year ago

JanMarcoRuizdeVargas commented 1 year ago

Hi, thanks for the great work on the package. I think I found a bug in GeneralGraph.subgraph() (causallearn.graph.GeneralGraph). I ran into it while writing code that builds on this method:

from causallearn.graph.GeneralGraph import GeneralGraph
import numpy as np
# cdag and cluster3 come from my own code; get_parents_plus returns a list of Node objects
_, relevant_nodes = cdag.get_parents_plus(cluster3)
# cdag.cg.G.subgraph(relevant_nodes)
subgraph = GeneralGraph(relevant_nodes)
graph = cdag.cg.G.graph  # ndarray
for i in range(cdag.cg.G.num_vars):
    print(i)
    if cdag.cg.G.nodes[i] not in relevant_nodes:
        print(cdag.cg.G.nodes[i].get_name())
        # the array shrinks on every deletion, but i keeps counting up
        graph = np.delete(graph, i, axis=0)

This raises: IndexError: index 8 is out of bounds for axis 0 with size 8

My code depends on my environment, but it follows the same logic as this self-contained example:

import numpy as np
array = np.zeros((5,5))
for i in range(5):
    for j in range(5):
        array[i,j] = i+j
delete = [1,2,4]
for i in range(5):
    if i in delete:
        # after two deletions only 3 rows remain, so i = 4 raises IndexError
        array = np.delete(array, i, axis=0)

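In causal-learn, GeneralGraph stores the graph as an ndarray, and subgraph() deletes rows/columns one at a time inside the loop. Each deletion shrinks the array and shifts the positions of all later rows, so the loop index no longer matches the original row: later deletions can remove the wrong row, and an index near the end of the loop can end up out of bounds.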
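A minimal sketch of the index shift, using a 1-D array of row labels in place of the graph matrix (a simplification assumed here for readability, so each surviving entry shows which original row it was). `iterative_delete` and `batch_delete` are hypothetical helper names: the first mimics the loop described above, the second follows the collect-then-delete-once pattern of the fix proposed further down.

```python
import numpy as np
from typing import List


def iterative_delete(labels: np.ndarray, to_delete: List[int]) -> np.ndarray:
    """Mimics the loop above: index i is applied to the current, shrinking array."""
    arr = labels.copy()
    for i in range(len(labels)):
        if i in to_delete:
            arr = np.delete(arr, i, axis=0)  # i now points at a shifted position
    return arr


def batch_delete(labels: np.ndarray, to_delete: List[int]) -> np.ndarray:
    """Collect all indices first, then delete them in a single call."""
    return np.delete(labels, to_delete, axis=0)


labels = np.arange(5)  # entry i holds its original row index i
try:
    iterative_delete(labels, [1, 2, 4])
except IndexError as err:
    print("iterative:", err)
print("batch:", batch_delete(labels, [1, 2, 4]))
```

With `[1, 2, 4]`, the iterative version removes original rows 1 and 3 and then fails at `i = 4` because only three rows are left, while the batch version keeps rows 0 and 3 as intended.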

Interestingly, when I restrict the graph's own node list directly (dropping only X3 below), I don't get an error:

from causallearn.graph.GraphClass import CausalGraph
test = CausalGraph(no_of_var=5, node_names=['X1','X2','X3','X4','X5'])
node_list = test.G.get_nodes()
restricted_nodes = node_list[0:2] + node_list[3:5]
subgraph = test.G.subgraph(restricted_nodes)
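One plausible reason this example does not fail, assuming subgraph() deletes inside the loop as described above: it removes only X3 (index 2), so the array shrinks once and no later index is affected. The sketch below (again on a 1-D label array, with a hypothetical `shrinking_loop` helper) also shows that two removals do not necessarily raise; they can silently drop the wrong row instead.

```python
import numpy as np
from typing import List


def shrinking_loop(labels: np.ndarray, to_delete: List[int]) -> np.ndarray:
    # Same pattern as in the report: delete index i from the already-shrunk array.
    arr = labels.copy()
    for i in range(len(labels)):
        if i in to_delete:
            arr = np.delete(arr, i, axis=0)
    return arr


labels = np.arange(5)
# One removal (like dropping X3): identical to deleting all indices at once.
print(shrinking_loop(labels, [2]), np.delete(labels, [2]))
# Two removals: no error, but original row 3 is dropped instead of row 2.
print(shrinking_loop(labels, [1, 2]), np.delete(labels, [1, 2]))
```

So a test case that removes a single node can pass even though the method is wrong in general.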

Am I missing something, or is this a bug?

A fix, which I have also submitted as a pull request (https://github.com/py-why/causal-learn/pull/118), would be to change the code to:

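def subgraph(self, nodes: List[Node]):
    # Method of GeneralGraph; np, List and Node are assumed to be imported at module level.
    subgraph = GeneralGraph(nodes)

    graph = self.graph

    # First collect the indices of all nodes that are not kept...
    nodes_to_delete = []
    for i in range(self.num_vars):
        if self.nodes[i] not in nodes:
            nodes_to_delete.append(i)

    # ...then delete them with one call per axis, so no index shifts mid-loop.
    graph = np.delete(graph, nodes_to_delete, axis=0)
    graph = np.delete(graph, nodes_to_delete, axis=1)

    subgraph.graph = graph
    subgraph.reconstitute_dpath(subgraph.get_graph_edges())

    return subgraph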
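The core of the proposed change can be checked outside causal-learn on a plain 2-D matrix. This is a sketch only: `restrict_adjacency` is a hypothetical helper, and node names stand in for causal-learn Node objects. It applies the same collect-then-delete step to both axes and keeps the remaining rows and columns in their original order.

```python
import numpy as np
from typing import List


def restrict_adjacency(matrix: np.ndarray, names: List[str], keep: List[str]) -> np.ndarray:
    """Hypothetical helper: keep only rows/columns whose node name is in `keep`.

    Mirrors the proposed fix: gather every index to drop first, then call
    np.delete once per axis, so earlier deletions cannot shift later indices.
    """
    to_drop = [i for i, name in enumerate(names) if name not in keep]
    reduced = np.delete(matrix, to_drop, axis=0)  # drop rows
    return np.delete(reduced, to_drop, axis=1)    # drop the same columns


names = ["X1", "X2", "X3", "X4", "X5"]
matrix = np.add.outer(np.arange(5), 10 * np.arange(5))  # entry [i, j] = i + 10 * j
sub = restrict_adjacency(matrix, names, ["X1", "X4"])
print(sub)
```

Here `sub` is `[[0, 30], [3, 33]]`, i.e. the rows and columns of X1 and X4. An equivalent alternative is to select the kept indices instead of deleting the others, e.g. `matrix[np.ix_(keep_idx, keep_idx)]`.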

Let me know what you think. Best, Jan Marco

JanMarcoRuizdeVargas commented 1 year ago

Resolved: the fix was merged into master.