y0-causal-inference / y0

❓y0 (pronounced "why not?") is for causal inference in Python
https://y0.readthedocs.io
BSD 3-Clause "New" or "Revised" License
45 stars 10 forks source link

Simplification of latent dags results in TypeError #218

Closed — pnavada closed this issue 6 months ago.

pnavada commented 6 months ago
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[12], line 1
----> 1 new_graph = remove_nuisance_variables(graph, treatments=treatment, outcomes=outcome)
File ~\PycharmProjects\eliater\venv\Lib\site-packages\eliater\discover_latent_nodes.py:179, in remove_nuisance_variables(graph, treatments, outcomes, tag)
    172 rv = NxMixedGraph(
    173     directed=graph.directed.copy(),
    174     undirected=graph.undirected.copy(),
    175 )
    176 lv_dag = mark_nuisance_variables_as_latent(
    177     graph=rv, treatments=treatments, outcomes=outcomes, tag=tag
    178 )
--> 179 simplified_latent_dag = simplify_latent_dag(lv_dag, tag=tag)
    180 return NxMixedGraph.from_latent_variable_dag(simplified_latent_dag.graph, tag=tag)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[12], line 1
----> 1 new_graph = remove_nuisance_variables(graph, treatments=treatment, outcomes=outcome)
File ~\PycharmProjects\eliater\venv\Lib\site-packages\eliater\discover_latent_nodes.py:179, in remove_nuisance_variables(graph, treatments, outcomes, tag)
    172 rv = NxMixedGraph(
    173     directed=graph.directed.copy(),
    174     undirected=graph.undirected.copy(),
    175 )
    176 lv_dag = mark_nuisance_variables_as_latent(
    177     graph=rv, treatments=treatments, outcomes=outcomes, tag=tag
    178 )
--> 179 simplified_latent_dag = simplify_latent_dag(lv_dag, tag=tag)
    180 return NxMixedGraph.from_latent_variable_dag(simplified_latent_dag.graph, tag=tag)
File ~\PycharmProjects\eliater\venv\Lib\site-packages\y0\algorithm\simplify_latent.py:48, in simplify_latent_dag(graph, tag)
     46 _, widows = remove_widow_latents(graph, tag=tag)
     47 _, unidirectional_latents = remove_unidirectional_latents(graph, tag=tag)
---> 48 _, redundant = remove_redundant_latents(graph, tag=tag)
     50 return SimplifyResults(
     51     graph=graph,
     52     widows=widows,
     53     unidirectional_latents=unidirectional_latents,
     54     redundant=redundant,
     55 )
File ~\PycharmProjects\eliater\venv\Lib\site-packages\y0\algorithm\simplify_latent.py:188, in remove_redundant_latents(graph, tag)
    176 def remove_redundant_latents(
    177     graph: nx.DiGraph, tag: Optional[str] = None
    178 ) -> Tuple[nx.DiGraph, Set[Variable]]:
    179     """Remove redundant latent variables.
    180
    181     W is a redundant latent variable if children of W are
   (...)
    186     :returns: The graph, modified in place
    187     """
--> 188     remove = set(_iter_redundant_latents(graph, tag=tag))
    189     graph.remove_nodes_from(remove)
    190     return graph, remove
File ~\PycharmProjects\eliater\venv\Lib\site-packages\y0\algorithm\simplify_latent.py:198, in _iter_redundant_latents(graph, tag)
    194 latents: Mapping[Variable, Set[Variable]] = {
    195     node: set(graph.successors(node)) for node in iter_latents(graph, tag=tag)
    196 }
    197 for (left, left_children), (right, right_children) in itt.product(latents.items(), repeat=2):
--> 198     if left_children == right_children and left > right:
    199         # if children are the same, keep the lower sort order node
    200         yield left
    201     elif left_children < right_children:
    202         # if left's children are a proper subset of right's children, we don't need left
TypeError: '>' not supported between instances of 'str' and 'Variable'
(comment edited)

**Possible fix** — module `simplify_latent.py`, function `transform_latents_with_parents`:

```python
for latent_node, parents, children in iter_middle_latents(graph, tag=tag):
    graph.remove_node(latent_node)
    graph.add_edges_from(itt.product(parents, children))
    new_node = Variable(f"{latent_node}{suffix}")
```

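`new_node` should be of type `str`, but it is a `Variable`. `_iter_redundant_latents` compares latent nodes with `>`, and the graph ends up containing both `str` and `Variable` latent nodes, which cannot be ordered against each other. Changing the line to `new_node = f"{latent_node}{suffix}"` fixes the issue.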