Below is the code snippet for `set_latent`:

```python
def set_latent(
    graph: nx.DiGraph,
    latent_nodes: Union[Variable, Iterable[Variable]],
    tag: Optional[str] = None,
) -> None:
    """Quickly set the latent variables in a graph."""
    if tag is None:
        tag = DEFAULT_TAG
    if isinstance(latent_nodes, Variable):
        latent_nodes = [latent_nodes]
    latent_nodes = set(latent_nodes)
    for node, data in graph.nodes(data=True):
        data[tag] = node in latent_nodes
```
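After defining an input graph and converting it to a latent variable DAG (`lv_dag`), let's look at the node info:

```python
for node, data in lv_dag.nodes.items():
    print(node, data)
```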
Output:

```
B {'hidden': False}
C {'hidden': False}
A {'hidden': False}
u_0 {'hidden': True}
```
Now, let's mark an additional node as latent:

```python
set_latent(lv_dag, [Variable("B")])
```
Now, let's check the node info again:

```python
for node, data in lv_dag.nodes.items():
    print(node, data)
```
Output:

```
B {'hidden': False}
C {'hidden': False}
A {'hidden': False}
u_0 {'hidden': False}
```
The expectation is that both `u_0` and `B` are latent.

The current implementation has two problems. First, it unmarks nodes that were already latent, because every node's tag is overwritten. Second, `node` in the for loop is a `str`, so it never matches the `Variable` objects passed in. As a result, the requested variables are not marked as latent either.
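Both failure modes can be reproduced without y0 or networkx. This is a minimal sketch under several assumptions. `Var` is a hypothetical stand-in for y0's `Variable`: a frozen dataclass compared by value. A plain dict of node name to attribute dict stands in for `lv_dag.nodes`. As reported, the node keys are `str`. The tag name `hidden` is taken from the printed output above. The loop body mirrors the original `data[tag] = node in latent_nodes`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for y0's Variable: hashable, compared by name."""

    name: str


def set_latent_original(node_data, latent_nodes, tag="hidden"):
    """Mirror of the reported loop: overwrites the tag on every node."""
    latent_nodes = set(latent_nodes)
    for node, data in node_data.items():
        # node is a str and latent_nodes holds Var objects, so this is always False
        data[tag] = node in latent_nodes


# Same starting state as the first output: only u_0 is latent
nodes = {
    "B": {"hidden": False},
    "C": {"hidden": False},
    "A": {"hidden": False},
    "u_0": {"hidden": True},
}

print("B" in {Var("B")})  # False: a str never equals a Var
set_latent_original(nodes, [Var("B")])
for node, data in nodes.items():
    print(node, data)  # every node, including u_0, now has 'hidden': False
```

The final loop prints the same four lines as the second output above.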
Possible fix:
```python
def set_latent(
    graph: nx.DiGraph,
    latent_nodes: Union[Variable, Iterable[Variable]],
    tag: Optional[str] = None,
) -> None:
    """Quickly set the latent variables in a graph."""
    if tag is None:
        tag = DEFAULT_TAG
    if isinstance(latent_nodes, Variable):
        latent_nodes = [latent_nodes]
    latent_nodes = set(latent_nodes)
    for node, data in graph.nodes(data=True):
        # Nodes are stored as str, so wrap them before comparing with Variables
        if Variable(node) in latent_nodes:
            # Only ever set the flag; nodes that were already latent keep theirs
            data[tag] = True
```
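Here is a quick check of the proposed fix under the same assumptions as before: the hypothetical `Var` stand-in, a dict in place of `lv_dag.nodes`, and `str` node keys. The loop only ever writes `True`, so nodes that were already latent keep their flag and nodes outside `latent_nodes` are left untouched.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for y0's Variable."""

    name: str


def set_latent_fixed(node_data, latent_nodes, tag="hidden"):
    """Mirror of the proposed fix: only ever sets the tag to True."""
    if isinstance(latent_nodes, Var):
        latent_nodes = [latent_nodes]
    latent_nodes = set(latent_nodes)
    for node, data in node_data.items():
        # Wrap the str key so it compares equal to the requested Var
        if Var(node) in latent_nodes:
            data[tag] = True


nodes = {
    "B": {"hidden": False},
    "C": {"hidden": False},
    "A": {"hidden": False},
    "u_0": {"hidden": True},
}
set_latent_fixed(nodes, [Var("B")])
for node, data in nodes.items():
    print(node, data)  # B and u_0 are latent; A and C are unchanged
```

One side effect of this design is that a node that was never tagged does not receive `hidden: False`. Code that reads the flag may therefore want to use `data.get(tag, False)`.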
This implementation expects you to know the entire definition of all nodes in the graph ahead of time, that is, to pass the complete set of latent nodes in a single call. I will make an update that provides a function that does what you want.
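With the replace semantics described in the reply, the caller passes the complete set of latent nodes on every call. The desired result can then be reached by taking the union of the currently latent nodes and the new ones. This is a minimal sketch of that usage pattern. It assumes `str` node keys, compares by name, and uses a dict in place of `lv_dag.nodes`. `set_latent_replace` and `current_latent` are hypothetical names, not part of y0.

```python
def set_latent_replace(node_data, latent_names, tag="hidden"):
    """Hypothetical replace-style setter: afterwards, exactly latent_names are latent."""
    latent_names = set(latent_names)
    for node, data in node_data.items():
        # Compare str keys against str names, so membership actually matches
        data[tag] = node in latent_names


def current_latent(node_data, tag="hidden"):
    """Hypothetical helper: names of nodes whose tag is currently set."""
    return {node for node, data in node_data.items() if data.get(tag, False)}


nodes = {
    "B": {"hidden": False},
    "C": {"hidden": False},
    "A": {"hidden": False},
    "u_0": {"hidden": True},
}
# Pass the complete latent set: what is already latent plus the new node
set_latent_replace(nodes, current_latent(nodes) | {"B"})
print(sorted(current_latent(nodes)))  # ['B', 'u_0']
```

Under this pattern the caller owns the full latent set. An additive helper like the proposed fix hides that bookkeeping.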