y0-causal-inference / y0

❓y0 (pronounced "why not?") is for causal inference in Python
https://y0.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Learnable backdoor with SymPyTorch #217

Open · cthoyt opened this issue 7 months ago

This takes some code out of #210 for later:

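```python
import pandas as pd
from sklearn.linear_model import LinearRegression

from y0.dsl import Variable
from y0.graph import NxMixedGraph


def get_single_door_learnable(
    graph: NxMixedGraph, data: pd.DataFrame
) -> dict[tuple[Variable, Variable], float]:
    """Estimate parameter values for a linear SCM using backdoor adjustment."""
    inference = graph.to_pgmpy_causal_inference()
    rv = {}
    for source, target in graph.directed.edges():
        try:
            # pgmpy raises ValueError when no valid backdoor adjustment set exists
            adjustment_sets = inference.get_all_backdoor_adjustment_sets(source.name, target.name)
        except ValueError:
            continue

        # TODO: two options, make the parameter learnable or specify a prior.
        # The lower and upper bounds could be interpreted either as the range
        # of a learnable parameter or as a prior.

        # pgmpy returns an empty frozenset when no adjustment is needed, in which
        # case we regress on the source alone; otherwise take the first set found
        adjustment_set = next(iter(adjustment_sets), frozenset())
        variables = sorted(adjustment_set | {source.name})
        idx = variables.index(source.name)
        # The source's coefficient in this regression is the edge estimate
        model = LinearRegression()
        model.fit(data[variables], data[target.name])
        rv[source, target] = model.coef_[idx]
    return rv
```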
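The function reads off the regression coefficient on `source` after adding a backdoor adjustment set to the regressors. Below is a minimal NumPy-only sketch of why that works, assuming a hypothetical linear SCM with made-up coefficients (Z → X = 0.8, Z → Y = 1.5, X → Y = 2.0); the helper name `ols_slopes` is illustrative and not part of y0. Adjusting for the backdoor set {Z} recovers the X → Y coefficient, while regressing Y on X alone also absorbs the confounding path X ← Z → Y.

```python
import numpy as np


def ols_slopes(design: np.ndarray, outcome: np.ndarray) -> np.ndarray:
    """Least-squares slopes for outcome ~ design (an intercept is fit but dropped)."""
    with_intercept = np.column_stack([np.ones(len(design)), design])
    coef, *_ = np.linalg.lstsq(with_intercept, outcome, rcond=None)
    return coef[1:]


# Hypothetical linear SCM: Z -> X (0.8), Z -> Y (1.5), X -> Y (2.0)
rng = np.random.default_rng(0)
n = 50_000
z = rng.normal(size=n)
x = 0.8 * z + rng.normal(size=n)
y = 2.0 * x + 1.5 * z + rng.normal(size=n)

# Backdoor adjustment: regress Y on X together with the adjustment set {Z},
# then read off the coefficient on X
adjusted = ols_slopes(np.column_stack([x, z]), y)[0]
# No adjustment: the X coefficient also absorbs the confounding path X <- Z -> Y
naive = ols_slopes(x[:, None], y)[0]
print(f"adjusted: {adjusted:.3f}, naive: {naive:.3f}")
```

Here `np.linalg.lstsq` stands in for `LinearRegression` so the example runs without scikit-learn, y0, or pgmpy installed.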
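The TODO about bounds leaves two interpretations open. Here is a minimal sketch of the first one, treating `lower`/`upper` as the range of a learnable coefficient. It is an assumption about how this could work, not the planned SymPyTorch implementation: plain NumPy gradient descent stands in for a PyTorch model, and `fit_bounded_coefficient`, `sigmoid`, and the simulated SCM are hypothetical. The edge coefficient is reparameterized as `lower + (upper - lower) * sigmoid(u)` so that no value of the unconstrained `u` can push it outside the bounds.

```python
import numpy as np


def sigmoid(u: float) -> float:
    """Logistic function mapping the real line onto (0, 1)."""
    return 1.0 / (1.0 + np.exp(-u))


def fit_bounded_coefficient(
    x: np.ndarray,
    covariates: np.ndarray,
    y: np.ndarray,
    lower: float,
    upper: float,
    lr: float = 0.1,
    steps: int = 2_000,
) -> float:
    """Fit y ~ theta * x + covariates @ beta + c with theta kept inside (lower, upper).

    theta = lower + (upper - lower) * sigmoid(u), so only the unconstrained u is
    learned; u, beta and c are fit by full-batch gradient descent on the MSE.
    """
    n, k = covariates.shape
    u, beta, c = 0.0, np.zeros(k), 0.0
    for _ in range(steps):
        s = sigmoid(u)
        theta = lower + (upper - lower) * s
        residual = theta * x + covariates @ beta + c - y
        # d(MSE)/d(theta), then chain rule through the sigmoid for u
        grad_theta = 2.0 * np.mean(residual * x)
        u -= lr * grad_theta * (upper - lower) * s * (1.0 - s)
        beta -= lr * 2.0 * covariates.T @ residual / n
        c -= lr * 2.0 * np.mean(residual)
    return float(lower + (upper - lower) * sigmoid(u))


# Hypothetical linear SCM: Z -> X (0.8), Z -> Y (1.5), X -> Y (2.0); {Z} is the backdoor set
rng = np.random.default_rng(0)
n = 20_000
z = rng.normal(size=n)
x = 0.8 * z + rng.normal(size=n)
y = 2.0 * x + 1.5 * z + rng.normal(size=n)

theta_wide = fit_bounded_coefficient(x, z[:, None], y, lower=0.0, upper=5.0)
theta_narrow = fit_bounded_coefficient(x, z[:, None], y, lower=0.0, upper=1.5)
print(f"bounds (0, 5): {theta_wide:.3f}; bounds (0, 1.5): {theta_narrow:.3f}")
```

When the true coefficient (2.0 here) lies inside the bounds, the fit recovers it; with `upper=1.5` the estimate saturates just below the upper bound instead. The prior interpretation would replace the hard constraint with a penalty on the loss (for example, a Gaussian prior corresponds to a quadratic penalty under MAP estimation).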