kklein opened this issue 3 months ago
Currently we are training:

- `n_folds` many $\hat{\mu}_0$ models
- `n_folds` many $\hat{\mu}_k$ models for every $k$
- `n_folds` many $\hat{\tau}_{0,k}$ models for every $k$
- `n_folds` many $\hat{\tau}_{k,0}$ models for every $k$

We would like to be able to answer in-sample queries of the kind: "Give me models $\hat{\tau}_{0,k}$ and $\hat{\tau}_{k,0}$ which have seen no information about sample $i$ at all."
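For the bookkeeping, a small counting helper may be useful; this is purely illustrative (`n_models_currently_trained` is not part of the library, and the propensity model is left out since the list above does not include it):

```python
def n_models_currently_trained(n_folds: int, n_variants: int) -> int:
    """Count the models the current scheme trains (illustrative bookkeeping only).

    n_variants counts the control group as variant 0, so there are
    n_variants - 1 'proper' treatment variants k = 1, ..., n_variants - 1.
    """
    n_treatment_variants = n_variants - 1
    n_mu_0 = n_folds                            # n_folds many mu_0 models
    n_mu_k = n_folds * n_treatment_variants     # n_folds many mu_k models per k
    n_tau_0k = n_folds * n_treatment_variants   # n_folds many tau_{0,k} models per k
    n_tau_k0 = n_folds * n_treatment_variants   # n_folds many tau_{k,0} models per k
    return n_mu_0 + n_mu_k + n_tau_0k + n_tau_k0

# Binary case with 5 folds: 5 + 5 + 5 + 5 = 20 models.
assert n_models_currently_trained(n_folds=5, n_variants=2) == 20
```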
In order to answer such a query, we could train:

- `n_folds` * `n_folds` many $\hat{\tau}_{0,k}$ models for every $k$
- `n_folds` * `n_folds` many $\hat{\tau}_{k,0}$ models for every $k$

In the scenario described in the issue, we would then run the `predict` method as follows:

- `n_folds * (n_folds - 1)` many models $\hat{\tau}_0(X_i)$ are based on $\hat{\mu}_k$ models which have not seen data point $i$; these model estimates can be aggregated
- `n_folds * (n_folds - 1)` many models $\hat{\tau}_1(X_i)$ have not seen $i$ themselves and may be based on any of the $\hat{\mu}_0$ models, none of which has seen the treated data point $i$; these model estimates can be aggregated
Importantly, this would only be relevant for in-sample predictions, i.e. `is_oos=False`.
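A rough sketch of how such a nested scheme could look in the binary case; everything below (helper structure, fold bookkeeping, the use of `LinearRegression`) is illustrative and simplified, not the library's implementation:

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold

rng = np.random.default_rng(0)

# Toy data with a binary treatment variant.
n = 200
X = rng.normal(size=(n, 3))
w = rng.integers(0, 2, size=n)
y = X[:, 0] + w * (1 + X[:, 1]) + rng.normal(scale=0.1, size=n)

n_folds = 5
treated = np.where(w == 1)[0]
control = np.where(w == 0)[0]

# Step 1: cross-fit n_folds many mu_1 models on the treated observations and
# remember, for each of them, which treated points it has NOT seen.
mu_1_models, mu_1_unseen = [], []
for train_idx, test_idx in KFold(n_folds, shuffle=True, random_state=0).split(treated):
    mu_1_models.append(
        clone(LinearRegression()).fit(X[treated[train_idx]], y[treated[train_idx]])
    )
    mu_1_unseen.append(set(treated[test_idx]))

# Step 2: for every mu_1 model, cross-fit n_folds many tau_0 models on the control
# observations, using the X-learner pseudo-outcome mu_1(X) - Y.
tau_0_models = []  # tau_0_models[j] holds the n_folds tau_0 models built on mu_1_models[j]
for mu_1 in mu_1_models:
    pseudo_outcome = mu_1.predict(X[control]) - y[control]
    tau_0_models.append(
        [
            clone(LinearRegression()).fit(X[control[train_idx]], pseudo_outcome[train_idx])
            for train_idx, _ in KFold(n_folds, shuffle=True, random_state=1).split(control)
        ]
    )

# Step 3: leak-free in-sample tau_0 estimate for a treated point i: only aggregate
# tau_0 models whose underlying mu_1 model has not seen i.
i = treated[0]
estimates = [
    tau_0.predict(X[[i]])[0]
    for j, per_mu_1 in enumerate(tau_0_models)
    if i in mu_1_unseen[j]
    for tau_0 in per_mu_1
]
print(f"in-sample tau_0 estimate for i={i}: {np.mean(estimates):.3f}")
```

The essential point is the selection rule at prediction time: only $\hat{\tau}_0$ models whose underlying $\hat{\mu}_1$ has not seen $i$ are admissible; how many such models exist depends on how the folds are defined, but the rule itself stays the same.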
Issue at hand
@ArseniyZvyagintsevQC brought the following to our attention:
Let us assume a binary treatment variant scenario in which we want to work with in-sample predictions, i.e. `is_oos=False`. The current implementation fits five models, three of which are considered nuisance models and two of which are considered treatment models:
"treatment_variant"
"treatment_variant"
"propensity_model"
"control_effect_model"
"treatment_effect_model"
More background on this here.
Note that each of these models is cross-fitted. More precisely, each is cross-fitted with respect to the data it was trained on.
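For concreteness, a minimal sketch of this scenario; the constructor arguments and the `predict(X, is_oos=False)` call below reflect my understanding of the metalearners `XLearner` API and may differ from the library in detail:

```python
import numpy as np
from lightgbm import LGBMClassifier, LGBMRegressor
from metalearners import XLearner

rng = np.random.default_rng(42)
n = 1_000
X = rng.normal(size=(n, 5))
w = rng.integers(0, 2, size=n)                 # binary treatment variant
y = X[:, 0] + w * X[:, 1] + rng.normal(size=n)

xlearner = XLearner(
    nuisance_model_factory=LGBMRegressor,      # the two per-variant outcome models
    propensity_model_factory=LGBMClassifier,   # "propensity_model"
    treatment_model_factory=LGBMRegressor,     # "control_effect_model" / "treatment_effect_model"
    is_classification=False,
    n_variants=2,
    n_folds=5,
)
xlearner.fit(X=X, y=y, w=w)

# In-sample CATE estimates: this is the code path affected by the leakage issue.
cate_hat = xlearner.predict(X, is_oos=False)
```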
Let's suppose now that we are at inference time and encounter an in-sample data point $i$. Without loss of generality, let's assume that $W_i=1$. In order to come up with a CATE estimate, the `predict` method will run

- the $\hat{\tau}_0$ prediction with `is_oos=True`, since this data point has not been seen during training time of the model $\hat{\tau}_0$
- the $\hat{\tau}_1$ prediction with `is_oos=False`, since this data point has indeed been seen during training time of the model $\hat{\tau}_1$

The latter call makes sure we avoid leakage in $\hat{\tau}_1$. The former call, however, does not completely avoid leakage: even though $i$ hasn't been seen in the training of $\hat{\tau}_0$, it has been seen in $\hat{\mu}_1$, which is, in turn, used by $\hat{\tau}_0$. Therefore, the observed outcome $Y_i$ can leak into the estimate $\hat{\tau}(X_i)$.
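To make the leakage path explicit, recall the standard X-learner pseudo-outcomes (as in Künzel et al.; spelled out here only for clarity, not quoted from the issue):

$$
\hat{\tau}_1 \text{ is fit on treated units } i \text{ with targets } Y_i - \hat{\mu}_0(X_i), \qquad
\hat{\tau}_0 \text{ is fit on control units } j \text{ with targets } \hat{\mu}_1(X_j) - Y_j.
$$

If the treated point $i$ was part of the training data of $\hat{\mu}_1$, then $Y_i$ has influenced $\hat{\mu}_1$ and thereby the targets $\hat{\mu}_1(X_j) - Y_j$ on which $\hat{\tau}_0$ is trained, even though $i$ itself never appears directly in $\hat{\tau}_0$'s training data.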
Next steps
We can devise an extreme, naïve approach to counteract this issue by training every type of model once per data point. Clearly, this ensures the absence of data leakage. The challenge with this issue revolves around coming up with a design that avoids this leakage at a more reasonable computational cost.
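As a point of reference, a sketch of what the extreme, naïve approach would amount to for a single model type (hypothetical helper, not library code); it makes the computational blow-up obvious, since every model type is refit once per data point:

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LinearRegression

def fit_leave_one_out(X, y, base_model=LinearRegression()):
    """Naive leakage-free scheme: one model per data point, each fit without that point.

    For N data points and M model types this means N * M fits -- prohibitive for
    anything but toy data, but trivially free of leakage: the model used for point i
    has never seen (X_i, y_i).
    """
    n = len(y)
    models = []
    for i in range(n):
        mask = np.arange(n) != i          # drop exactly the i-th observation
        models.append(clone(base_model).fit(X[mask], y[mask]))
    return models

# Usage: models[i] is the only model allowed to produce an in-sample prediction for i.
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = X @ np.array([1.0, -1.0, 0.5]) + rng.normal(scale=0.1, size=50)
models = fit_leave_one_out(X, y)
print(models[0].predict(X[[0]]))
```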