april-tools / cirkit

A Python framework to build, learn, and reason about probabilistic circuits and tensor networks.
https://cirkit-docs.readthedocs.io/en/latest/
GNU General Public License v3.0
71 stars 1 forks source link

Gaussian Chow-Liu Algorithm and Linear region graph #150

Closed loreloc closed 1 week ago

loreloc commented 11 months ago

Below I am adding the code (and tests) to learn a region graph (RG) corresponding to a hidden Chow-Liu tree (HCLT) structure over continuous observed variables, as well as the code to construct a "linear" RG. Please add some documentation and check that it works.

cc @gengala

Note that the root of the tree is chosen as the node that minimizes the median (unweighted, i.e. hop-count) distance to all the other nodes. Are there other heuristics?

import warnings
from typing import Optional, List

import numpy as np
from scipy import sparse

from region_graph import RegionGraph, RegionNode, PartitionNode

def _gaussian_mutual_information(data: np.ndarray) -> np.ndarray:
    """Return the pairwise mutual information of the columns of *data* under a Gaussian model.

    For two jointly Gaussian variables with Pearson correlation ``r`` the mutual
    information is ``-0.5 * log(1 - r**2)``.

    Args:
        data: Array of shape (num_samples, num_variables) with at least two columns;
            each column is a variable.

    Returns:
        A symmetric (num_variables, num_variables) array with a zero diagonal.
    """
    # Columns with an undefined correlation (e.g. constant columns) yield NaN; silence
    # the division warnings here and treat such pairs as independent below.
    with np.errstate(divide="ignore", invalid="ignore"):
        ccoeff = np.corrcoef(data, rowvar=False)
    r2 = np.square(ccoeff)
    r2[np.isnan(r2)] = 0.0
    # Cap perfect correlations just below 1 so their mutual information is large but
    # finite. Otherwise it would be +inf and the strongest edges would have to be dropped.
    r2 = np.minimum(r2, 1.0 - np.finfo(np.float64).eps)
    mutual_info = -0.5 * np.log1p(-r2)
    # A variable's information with itself must not become a self-loop in the graph.
    np.fill_diagonal(mutual_info, 0.0)
    return mutual_info


def _chow_liu_tree(data: np.ndarray) -> List[int]:
    """Learn a Chow-Liu tree over the columns of *data*.

    Returns:
        A predecessor list ``tree`` with ``tree[v]`` the parent of variable ``v`` and
        ``tree[root] == -1``.
    """
    num_variables = data.shape[1]
    if num_variables == 1:
        return [-1]
    mutual_info = _gaussian_mutual_information(data)
    # The maximum-MI spanning tree is the minimum spanning tree of -MI. Every spanning
    # tree has exactly num_variables - 1 edges, so shifting all off-diagonal weights by
    # the same constant leaves the optimal tree unchanged. It also makes every weight
    # nonzero: csgraph treats zero entries as missing edges, so zero-MI pairs could
    # otherwise leave the graph disconnected.
    weights = -(mutual_info + 1.0)
    np.fill_diagonal(weights, 0.0)
    mst = sparse.csgraph.minimum_spanning_tree(weights, overwrite=True)
    # All-pairs hop counts on the tree. Dijkstra avoids Floyd-Warshall's O(n^3) cost.
    # The edge weights are ignored (unweighted=True); abs() only keeps them positive,
    # as Dijkstra expects.
    dists = sparse.csgraph.shortest_path(
        abs(mst), method="D", directed=False, unweighted=True
    )
    # Root heuristic: the node with the smallest median hop distance to all other nodes.
    root = int(np.argmin(np.median(dists, axis=0)))
    _, predecessors = sparse.csgraph.breadth_first_order(
        mst, directed=False, i_start=root, return_predecessors=True
    )
    # breadth_first_order marks the root with a negative sentinel; normalise it to -1.
    predecessors[root] = -1
    return predecessors.tolist()


def _tree_to_region_graph(tree: List[int]) -> RegionGraph:
    """Convert a tree given as a predecessor list into a region graph.

    Every variable ``v`` gets a leaf region ``{v}``. If ``v`` has descendants, a
    partition node over the scope of ``v``'s subtree takes as inputs the leaf region of
    ``v`` and the regions of ``v``'s children. It outputs a region node with the same
    scope, which in turn feeds the partition node of ``v``'s parent.

    Args:
        tree: ``tree[v]`` is the parent of variable ``v``, or -1 for the root.

    Returns:
        The region graph.
    """
    num_variables = len(tree)
    # Scope of the subtree rooted at each variable: the variable plus all its descendants.
    scopes = [{v} for v in range(num_variables)]
    for v in range(num_variables):
        ancestor = tree[v]
        while ancestor != -1:
            scopes[ancestor].add(v)
            ancestor = tree[ancestor]

    # Only variables with at least one descendant get a partition node.
    partitions: List[Optional[PartitionNode]] = [
        PartitionNode(set(scope)) if len(scope) > 1 else None for scope in scopes
    ]

    rg = RegionGraph()
    regions: List[Optional[RegionNode]] = [None] * num_variables
    for v in range(num_variables):
        parent = tree[v]
        leaf_region = RegionNode({v})
        if partitions[v] is None:
            regions[v] = leaf_region
            if parent == -1:
                # Single-node tree: the leaf region is the whole region graph.
                rg.add_node(leaf_region)
        else:
            rg.add_edge(leaf_region, partitions[v])
            regions[v] = RegionNode(set(scopes[v]))
            rg.add_edge(partitions[v], regions[v])
        if parent != -1:
            rg.add_edge(regions[v], partitions[parent])
    return rg


def ChowLiuGaussian(data: np.ndarray) -> RegionGraph:
    """Learn a region graph shaped like a Chow-Liu tree over continuous variables.

    Pairwise mutual information is estimated from the Pearson correlations of the
    columns of *data*, assuming a Gaussian model. The maximum-MI spanning tree is
    rooted at the node with the smallest median hop distance to all other nodes and
    then converted into a region graph.

    Args:
        data: Array of shape (num_samples, num_variables); each column is a variable.

    Returns:
        The learned region graph.

    Raises:
        ValueError: If *data* is not two-dimensional or has no columns.
    """
    data = np.asarray(data)
    if data.ndim != 2 or data.shape[1] == 0:
        raise ValueError(
            f"Expected a 2D array with at least one column, got shape {data.shape}"
        )
    return _tree_to_region_graph(_chow_liu_tree(data))

# TESTS

import numpy as np
import pytest

from region_graph.chow_liu import ChowLiuGaussian
from tests.region_graph.test_region_graph import (
    check_region_graph_save_load,
    check_region_partition_layers,
)

@pytest.mark.parametrize(
    "num_variables", [2, 5, 12]
)
def test_chow_liu_gaussian(num_variables: int) -> None:
    """Learn a Chow-Liu region graph from synthetic data and check it is well formed."""
    # Seeded generator so that the data, and hence the learned tree, are reproducible.
    data = np.random.RandomState(42).randn(10_000, num_variables)
    # Make each column a non-linear function of the two preceding ones so the variables
    # are dependent. Negative indices wrap around, so the first columns use the last ones.
    for j in range(data.shape[1]):
        data[:, j] = data[:, j - 1] * np.sin(data[:, j]) - np.square(data[:, j - 2])
    rg = ChowLiuGaussian(data)
    assert rg.num_variables == num_variables
    assert rg.is_smooth
    assert rg.is_decomposable
    assert rg.is_structured_decomposable
    check_region_partition_layers(rg, bottom_up=True)
    check_region_partition_layers(rg, bottom_up=False)
    check_region_graph_save_load(rg)
import numpy as np

from region_graph import RegionGraph, RegionNode, PartitionNode

def LinearVTree(num_variables: int, num_repetitions: int = 1, randomize: bool = False, seed: int = 42) -> RegionGraph:
    """Build a "linear" region graph that splits off one variable at a time.

    Each repetition attaches a chain of partitions below a shared root region. Given a
    variable order ``order``, the partition under the region over ``order[i:]`` splits
    it into the leaf region ``{order[i]}`` and the region over ``order[i + 1:]``. This
    continues until a single variable is left.

    Args:
        num_variables: Number of variables; the root region has scope
            ``{0, ..., num_variables - 1}``.
        num_repetitions: Number of chains attached to the root region.
        randomize: If True, each chain uses a freshly shuffled variable order;
            otherwise the order is ``0, 1, ..., num_variables - 1``.
        seed: Seed of the random generator used to shuffle when *randomize* is True.

    Returns:
        The region graph.

    Raises:
        ValueError: If *num_variables* is smaller than 1.
    """
    if num_variables < 1:
        raise ValueError(f"num_variables must be at least 1, got {num_variables}")
    root = RegionNode(set(range(num_variables)))
    rg = RegionGraph()
    rg.add_node(root)
    random_state = np.random.RandomState(seed)

    for _ in range(num_repetitions):
        order = list(range(num_variables))
        if randomize:
            random_state.shuffle(order)
        parent_node = root
        for i, v in enumerate(order[:-1]):
            # Split the current region into {v} and the region over the remaining variables.
            partition_node = PartitionNode(set(parent_node.scope))
            rg.add_edge(partition_node, parent_node)
            leaf_node = RegionNode({v})
            rest_node = RegionNode(set(order[i + 1:]))
            rg.add_edge(leaf_node, partition_node)
            rg.add_edge(rest_node, partition_node)
            parent_node = rest_node

    return rg

# TESTS

import itertools

import pytest

from region_graph.linear_vtree import LinearVTree
from tests.region_graph.test_region_graph import (
    check_region_graph_save_load,
    check_region_partition_layers,
)

@pytest.mark.parametrize(
    "num_variables,num_repetitions,randomize",
    itertools.product([1, 2, 5, 12], [1, 3], [False, True]),
)
def test_linear_vtree(num_variables: int, num_repetitions: int, randomize: bool) -> None:
    """Check the structural properties of linear region graphs."""
    region_graph = LinearVTree(num_variables, num_repetitions=num_repetitions, randomize=randomize)
    assert region_graph.num_variables == num_variables
    assert region_graph.is_smooth
    assert region_graph.is_decomposable
    # With several randomized repetitions the chains may use different variable
    # orders, so structured decomposability is only required otherwise.
    expect_structured = num_repetitions == 1 or not randomize
    if expect_structured:
        assert region_graph.is_structured_decomposable
    for bottom_up in (True, False):
        check_region_partition_layers(region_graph, bottom_up=bottom_up)
    check_region_graph_save_load(region_graph)
arranger1044 commented 11 months ago

thanks!

Let's be consistent with the nomenclature and replace "vtree" with "region_graph".

gengala commented 11 months ago

Maybe this function (the link was not preserved in this export) runs faster than floyd_warshall.

I used it for my own structures here.

gengala commented 11 months ago

I would also use a chunked version of np.corrcoef to avoid out-of-memory (OOM) errors.

arranger1044 commented 11 months ago

Nice, give it a try @gengala!

gengala commented 11 months ago

I'll try asap, quite busy these days :/

Ah, if you need or want to double-check that your trees are correct, you can check this file. Anyway, wouldn't it be better to have a function tree2region_graph?

loreloc commented 11 months ago

ah, if you need/want to double-check that your trees are correct you can check this file. Anyway, isn't it better to have a function tree2region_graph ?

Yes indeed. Thank you!

loreloc commented 1 month ago

There is an implementation by @gengala in the clt branch. Some effort might be required to merge it, though.