y0-causal-inference / y0

❓y0 (pronounced "why not?") is for causal inference in Python
https://y0.readthedocs.io
BSD 3-Clause "New" or "Revised" License

set_latent marks all nodes as non-latent #190

Closed: pnavada closed this issue 2 months ago

pnavada commented 1 year ago

Below is the current code for set_latent:

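import networkx as nx
from typing import Iterable, Optional, Union

from y0.dsl import Variable
from y0.graph import DEFAULT_TAG  # "hidden", matching the output below

def set_latent(
    graph: nx.DiGraph,
    latent_nodes: Union[Variable, Iterable[Variable]],
    tag: Optional[str] = None,
) -> None:
    """Quickly set the latent variables in a graph."""
    if tag is None:
        tag = DEFAULT_TAG
    if isinstance(latent_nodes, Variable):
        latent_nodes = [latent_nodes]

    latent_nodes = set(latent_nodes)
    for node, data in graph.nodes(data=True):
        # Assigns the tag on every node, so nodes that were already latent get reset to False
        data[tag] = node in latent_nodes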

Let's define an input graph:

from y0.graph import NxMixedGraph

graph1 = NxMixedGraph.from_str_adj(
    directed={"A": ["B"]},
    undirected={"B": ["C"]},
)

Now, let's convert it to a latent variable DAG:

lv_dag = graph1.to_latent_variable_dag()
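Conceptually, this conversion replaces each bidirected (undirected) edge with a fresh latent parent of both endpoints. The following is a minimal sketch of that idea using plain dictionaries, not y0's implementation: to_latent_dag_sketch is a hypothetical helper, nodes are assumed to be strings, and the u_<i> naming and "hidden" tag are inferred from the output shown in this issue.

```python
from typing import Dict, List, Tuple


def to_latent_dag_sketch(
    directed: Dict[str, List[str]],
    undirected: Dict[str, List[str]],
    tag: str = "hidden",
) -> Tuple[Dict[str, dict], List[Tuple[str, str]]]:
    """Hypothetical helper: replace each bidirected edge with a latent parent."""
    nodes: Dict[str, dict] = {}
    edges: List[Tuple[str, str]] = []

    def add_observed(name: str) -> None:
        # Observed nodes are tagged as not hidden (setdefault keeps an existing entry).
        nodes.setdefault(name, {tag: False})

    # Directed edges are copied over unchanged.
    for source, targets in directed.items():
        add_observed(source)
        for target in targets:
            add_observed(target)
            edges.append((source, target))

    # Each bidirected edge X <-> Y becomes u_i -> X and u_i -> Y, with u_i hidden.
    counter = 0
    for left, rights in undirected.items():
        add_observed(left)
        for right in rights:
            add_observed(right)
            latent = f"u_{counter}"
            counter += 1
            nodes[latent] = {tag: True}
            edges.extend([(latent, left), (latent, right)])
    return nodes, edges


nodes, edges = to_latent_dag_sketch({"A": ["B"]}, {"B": ["C"]})
print(nodes)
print(edges)
```

With the example graph, A, B and C end up with hidden set to False and the single new node u_0 with hidden set to True, which is the same tagging the real conversion produces here.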

Now let's inspect the node attributes:

for node, data in lv_dag.nodes.items():
    print(node, data)

The output is:

B {'hidden': False}
C {'hidden': False}
A {'hidden': False}
u_0 {'hidden': True}

Now, let's mark an additional node as latent:

from y0.dsl import Variable
from y0.graph import set_latent

set_latent(lv_dag, [Variable("B")])

Now, let's check the node attributes again:

for node, data in lv_dag.nodes.items():
    print(node, data)

The output is:

B {'hidden': False}
C {'hidden': False}
A {'hidden': False}
u_0 {'hidden': False}

The expected result is that both u_0 and B are marked as latent.

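The current implementation has two problems. First, it assigns the tag on every node, so nodes that were already latent (here u_0) are unmarked. Second, the nodes yielded by the for loop are of type str while latent_nodes contains Variable objects, so the membership check never succeeds and the requested variables are not marked as latent either.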
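Both problems can be reproduced without y0. This minimal sketch mirrors the loop in set_latent, assuming (as observed in this issue) that graph nodes are str while latent_nodes holds Variable objects. A dict stands in for graph.nodes(data=True), and Variable here is a hypothetical frozen dataclass standing in for y0.dsl.Variable, so the real class may differ in detail.

```python
from dataclasses import dataclass
from typing import Dict, Iterable


@dataclass(frozen=True)
class Variable:
    """Hypothetical stand-in for y0.dsl.Variable: a hashable wrapper around a name."""
    name: str


def set_latent_current(
    nodes: Dict[str, dict], latent_nodes: Iterable[Variable], tag: str = "hidden"
) -> None:
    """Mirrors the loop in the current set_latent, with a dict as the node-data view."""
    latent = set(latent_nodes)
    for node, data in nodes.items():
        # Assigns to every node, so any tag that was True is overwritten.
        data[tag] = node in latent


nodes = {
    "A": {"hidden": False},
    "B": {"hidden": False},
    "C": {"hidden": False},
    "u_0": {"hidden": True},
}
set_latent_current(nodes, [Variable("B")])

# Problem 1: the existing latent node u_0 has been unmarked.
print(nodes["u_0"])  # {'hidden': False}
# Problem 2: the str "B" never matches Variable("B"), so B is not marked either.
print("B" in {Variable("B")})  # False
print(nodes["B"])  # {'hidden': False}
```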

Possible fix:

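import networkx as nx
from typing import Iterable, Optional, Union

from y0.dsl import Variable
from y0.graph import DEFAULT_TAG

def set_latent(
    graph: nx.DiGraph,
    latent_nodes: Union[Variable, Iterable[Variable]],
    tag: Optional[str] = None,
) -> None:
    """Quickly set the latent variables in a graph."""
    if tag is None:
        tag = DEFAULT_TAG
    if isinstance(latent_nodes, Variable):
        latent_nodes = [latent_nodes]

    latent_nodes = set(latent_nodes)
    for node, data in graph.nodes(data=True):
        # Nodes are str here, so wrap them before comparing against the set of Variables
        if Variable(node) in latent_nodes:
            # Only ever set the tag; existing latent markers are left untouched
            data[tag] = True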
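A slightly more defensive variant of the proposed fix compares by name, so it works whether the graph's nodes are str or Variable objects, and it never clears an existing tag. This is a sketch under stated assumptions: mark_latent is a hypothetical name, a dict stands in for graph.nodes(data=True), and Variable is a minimal stand-in that only has a name attribute.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Union


@dataclass(frozen=True)
class Variable:
    """Hypothetical stand-in for y0.dsl.Variable."""
    name: str


def mark_latent(
    nodes: Dict[object, dict],
    latent_nodes: Union[Variable, Iterable[Variable]],
    tag: str = "hidden",
) -> None:
    """Hypothetical additive setter: only ever sets the tag to True."""
    if isinstance(latent_nodes, Variable):
        latent_nodes = [latent_nodes]
    # Compare by name so that both str and Variable node keys can match.
    latent_names = {variable.name for variable in latent_nodes}
    for node, data in nodes.items():
        name = node.name if isinstance(node, Variable) else node
        if name in latent_names:
            data[tag] = True


nodes = {
    "A": {"hidden": False},
    "B": {"hidden": False},
    "C": {"hidden": False},
    "u_0": {"hidden": True},
}
mark_latent(nodes, Variable("B"))
print(sorted(node for node, data in nodes.items() if data["hidden"]))  # ['B', 'u_0']
```

Because the loop never writes False, calling it repeatedly only grows the set of latent nodes, which is the behavior the issue asks for.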
cthoyt commented 1 year ago

This implementation expects you to know the complete definition of all nodes in the graph ahead of time. I will make an update to provide a function that does what you want.
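Under the existing "replace everything" semantics, the caller has to pass the full set of latent nodes. A minimal workaround sketch, assuming string node keys and using hypothetical helper names (set_latent_exact mirrors the current loop; current_latent is illustrative only), reads the existing tags first and passes the union, keyed by the same type the graph uses.

```python
from typing import Dict, Iterable, Set


def set_latent_exact(
    nodes: Dict[str, dict], latent_names: Iterable[str], tag: str = "hidden"
) -> None:
    """Mirrors the current semantics: the given names become the complete latent set."""
    latent = set(latent_names)
    for node, data in nodes.items():
        data[tag] = node in latent


def current_latent(nodes: Dict[str, dict], tag: str = "hidden") -> Set[str]:
    """Collect the nodes that are already tagged as latent."""
    return {node for node, data in nodes.items() if data.get(tag, False)}


nodes = {
    "A": {"hidden": False},
    "B": {"hidden": False},
    "C": {"hidden": False},
    "u_0": {"hidden": True},
}
# Pass existing plus new latent nodes, using str keys to match the graph's node type.
set_latent_exact(nodes, current_latent(nodes) | {"B"})
print(sorted(current_latent(nodes)))  # ['B', 'u_0']
```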