DataCanvasIO / YLearn

YLearn, a pun of "learn why", is a python package for causal inference
https://ylearn.readthedocs.io
Apache License 2.0

Policy Optimization API Usage #56

Open · zhj2022 opened this issue 5 months ago

zhj2022 commented 5 months ago

Hi, I'm trying to do policy optimization using YLearn. I have read the docs on this but didn't fully understand them. Formally, a policy optimization problem can be written as $x^{*}=\arg\max_x \mathbb{E}[\mathcal{Y}\mid\text{do}(\mathcal{X}=x), \mathcal{S}]$. How are $\mathcal{Y}$, $\mathcal{X}$ and $\mathcal{S}$ represented, respectively, in the arguments of the `est.fit()` API described in https://ylearn.readthedocs.io/en/latest/sub/policy.html? I need a more concrete explanation to use the API properly, thanks!

BochenLv commented 5 months ago

Hi,

Thanks for the question. The goal of policy optimization is to find the specific treatment that maximizes the causal effect among all possible treatment values. Thus $\mathcal{Y}$ should be the causal effects, and $\mathcal{X}$ should include the possible treatments.
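As a small illustration of this goal, the sketch below picks, for a single unit with covariate value `s`, the treatment with the largest causal effect among a finite set of candidates. The function `effect_of` is a hypothetical, hand-written stand-in for the effect of treatment `x` given $\mathcal{S}=s$; it is not part of YLearn.

```python
def effect_of(x: int, s: float) -> float:
    """Hypothetical causal effect of treatment x for a unit with covariate s."""
    # Treatment 1 helps when s > 0; treatment 0 helps when s < 0.
    return s if x == 1 else -s


def best_treatment(s: float, treatments=(0, 1)) -> int:
    """Return the treatment in `treatments` with the largest effect for s."""
    return max(treatments, key=lambda x: effect_of(x, s))


print(best_treatment(0.7))   # 1
print(best_treatment(-0.3))  # 0
```

The brute-force `max` only works because the candidate treatments are finite and their effects are known; this is why the effects, not the treatment variable itself, are what the policy model needs as input.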

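As for the `fit()` function, you first need to provide your data; then:

* $\mathcal{Y}$ can be given in one of three ways: by specifying `effect` (a string, the name of the causal effects in your data); by specifying `effect_array` (an additional array of causal effects) if your data does not include the causal effects directly; or, if you have not calculated the causal effects at all, by providing a trained `est_model`, which will then be used to calculate them.

* $\mathcal{X}$ need not be given, since this information is already included in the causal effects when you specify $\mathcal{Y}$.

* $\mathcal{S}$ should include all the other relevant variables in your data and is given by `covariate` (the names of these variables in your data).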

zhj2022 commented 5 months ago

> Hi,
>
> Thanks for the question. The goal of policy optimization is to find the specific treatment that maximizes the causal effect among all possible treatment values. Thus $\mathcal{Y}$ should be the causal effects, and $\mathcal{X}$ should include the possible treatments.
>
> As for the `fit()` function, you first need to provide your data; then:
>
> * $\mathcal{Y}$ can be given in one of three ways: by specifying `effect` (a string, the name of the causal effects in your data); by specifying `effect_array` (an additional array of causal effects) if your data does not include the causal effects directly; or, if you have not calculated the causal effects at all, by providing a trained `est_model`, which will then be used to calculate them.
>
> * $\mathcal{X}$ need not be given, since this information is already included in the causal effects when you specify $\mathcal{Y}$.
>
> * $\mathcal{S}$ should include all the other relevant variables in your data and is given by `covariate` (the names of these variables in your data).
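To make the mapping concrete, here is a minimal NumPy sketch of the inputs described in the list above. The column names (`age`, `income`, `region`), the random values, and the plain dictionary standing in for the dataset are all hypothetical; with YLearn you would use your real data, and the names in `covariate` would refer to its columns.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5  # number of examples (N)

# Hypothetical dataset: column name -> one value per example.
data = {
    "age": rng.integers(20, 60, size=n).astype(float),
    "income": rng.normal(50.0, 10.0, size=n),
    "region": rng.integers(0, 3, size=n).astype(float),
}

# S: the relevant variables, selected by name as `covariate` does.
# "region" is assumed irrelevant here, so it is left out.
covariate = ["age", "income"]
S = np.column_stack([data[name] for name in covariate])

# Y: one causal effect per (example, treatment) pair, shape (N, J).
# X is never a column: treatment j is simply column j of this array.
J = 3
effect_array = rng.normal(size=(n, J))

print(S.shape)             # (5, 2)
print(effect_array.shape)  # (5, 3)
```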

Thanks for your reply!

In the example given in the doc, the array $y$, which is passed as the argument `effect_array`, has two columns. Does the first column stand for $\mathcal{X}$ and the second for $\mathcal{Y}$? And if `effect_array` is given an array with more than two columns, how are $\mathcal{X}$ and $\mathcal{Y}$ represented? I'm also confused about the meaning of the return value, which is a (1000,) array.

BochenLv commented 5 months ago

The shape of an allowed `effect_array` (usually a NumPy array) should be $(N, J)$, where $N$ is the number of examples in your data and $J$ is the number of possible treatment values. The $(n, j)$-th element of `effect_array` is the causal effect of taking the $j$-th treatment on the $n$-th example, i.e. `effect_array[n, j]`.
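The indexing convention can be checked on a small hand-made array. The values below are made up, and NumPy indices start at 0. The last lines take the best treatment per example with `argmax(axis=1)`, which yields one entry per example (shape `(N,)`); this greedy per-example choice is only an illustration of the layout and is not meant to describe how YLearn's policy tree is trained.

```python
import numpy as np

# Toy effect_array with N = 4 examples (rows) and J = 3 treatments (columns).
effect_array = np.array([
    [0.1, 0.5, 0.2],
    [0.9, 0.3, 0.4],
    [0.0, 0.2, 0.7],
    [0.6, 0.6, 0.1],
])
N, J = effect_array.shape

n, j = 2, 2
print(effect_array[n, j])  # 0.7, the effect of treatment j on example n

# Index of the largest effect in each row (ties go to the first column).
best = effect_array.argmax(axis=1)
print(best)        # [1 0 2 0]
print(best.shape)  # (4,)
```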

Specifically, for the example provided in the doc,

zhj2022 commented 5 months ago

Thanks, I now understand what specifying `effect_array` means. What if I don't have all the treatment effects? For example, for the $i$-th example I only have the effect of treatment 1, while for the $k$-th example I only have the effect of treatment 2. How should I specify the argument `effect_array` in that case?

BochenLv commented 5 months ago

The important thing is that you need the full set of treatment effects to optimize the policy, because optimizing the policy means selecting the suitable treatment values. In other words, if for an example you only have the effect of treatment 1, there is no information on which to base the selection, so training the policy tree is not feasible.
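One way to see the problem is to mark missing effects explicitly and check which rows are incomplete. Using `np.nan` for a missing effect is an assumption made for this sketch; the thread does not say how YLearn handles missing entries. Treatments are indexed from 0 here.

```python
import numpy as np

# Incomplete effects, as in the question: example 0 only has the effect of
# treatment 0, and example 1 only has the effect of treatment 1.
# np.nan marks a missing effect (an assumption of this sketch).
effect_array = np.array([
    [0.4, np.nan],
    [np.nan, 0.8],
    [0.2, 0.5],
])

# A row can only be used to compare treatments if every entry is present.
incomplete = np.isnan(effect_array).any(axis=1)
print(incomplete)  # [ True  True False]

rows = np.flatnonzero(incomplete).tolist()
if rows:
    print(f"cannot compare treatments for examples {rows}")
```

Only the last row supports a comparison; for the first two there is nothing to rank, which is the infeasibility described above.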

If you don't have a calculated `effect_array`, you can simply pass a trained `est_model` (any kind of estimator model provided by YLearn) as an argument to the `fit` method of the policy tree; the `fit` method will then calculate the effects automatically.
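Conceptually, a trained estimator fills in every column of the $(N, J)$ matrix by evaluating the effect of each candidate treatment for each example. The sketch below uses a hypothetical linear function `toy_effect` in place of a real YLearn estimator, and it does not claim to reproduce how `fit()` calls `est_model` internally.

```python
import numpy as np


def toy_effect(s: np.ndarray, treatment: int) -> np.ndarray:
    """Hypothetical trained model: effect of `treatment` for every row of s."""
    weights = np.array([
        [1.0, -1.0],  # treatment 0
        [-1.0, 1.0],  # treatment 1
        [0.5, 0.5],   # treatment 2
    ])
    return s @ weights[treatment]


# Covariates S for N = 3 examples.
S = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
treatments = [0, 1, 2]

# Evaluate the model once per candidate treatment; each call gives one column.
effect_array = np.column_stack([toy_effect(S, t) for t in treatments])
print(effect_array.shape)           # (3, 3)
print(effect_array.argmax(axis=1))  # [0 1 2]
```

Because the model can be queried for any treatment, no entry is missing, which is what makes the `est_model` route usable when the effects were never observed directly.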

zhj2022 commented 5 months ago

Okay, your explanation is quite clear! Thanks!