brash6 opened this issue 4 months ago
Today, the overall `med_bench` structure looks like this:
```
└── src
    ├── med_bench
    │   ├── get_estimation.py
    │   ├── get_simulated_data.py
    │   ├── mediation.py
    │   └── utils
    │       ├── constants.py
    │       ├── nuisances.py
    │       └── utils.py
    └── tests
        ├── estimation
        │   ├── generate_tests_results.py
        │   ├── test_exact_estimation.py
        │   ├── test_get_estimation.py
        │   └── tests_results.npy
        └── simulate_data
            └── test_get_simulated_data.py
```
Following the work started by @houssamzenati on modularizing the estimation functions, I suggest we refactor the package as follows:
```
└── src
    ├── med_bench
    │   ├── get_estimation.py
    │   ├── get_simulated_data.py
    │   ├── estimation
    │   │   ├── base.py
    │   │   ├── dml.py
    │   │   ├── g_computation.py
    │   │   ├── ipw.py
    │   │   ├── linear.py
    │   │   ├── mr.py
    │   │   └── tmle.py
    │   ├── nuisances
    │   │   ├── conditional_outcome.py
    │   │   ├── cross_conditional_outcome.py
    │   │   ├── density.py (to be removed?)
    │   │   ├── kme.py (to be removed?)
    │   │   ├── propensities.py
    │   │   └── utils.py
    │   └── utils
    │       ├── config.py
    │       ├── decorators.py
    │       ├── kernel.py (to be removed?)
    │       ├── loader.py
    │       ├── parse.py
    │       ├── plots.py
    │       ├── scores.py
    │       └── utils.py
    └── tests
        ├── estimation
        │   ├── generate_tests_results.py
        │   ├── base.py
        │   ├── test_dml.py
        │   ├── test_g_computation.py
        │   ├── test_ipw.py
        │   ├── test_linear.py
        │   ├── test_mr.py
        │   ├── test_tmle.py
        │   ├── test_utils.py
        │   └── tests_results.npy
        └── simulate_data
            └── test_get_simulated_data.py
```
With the following properties (to be discussed together more deeply):

- `estimation/base.py`: a file containing the parent class for each estimator. This class has several abstract and concrete methods (see the sketch after this list). TBD: could this class be designed to inherit from a scikit-learn base model, so that any scikit-learn classifier or regressor could be used in each method?
- `estimation/dml.py`, `estimation/g_computation.py` and every other file in the estimation folder contain classes inherited from the parent class in `base.py`, and define concrete functions for each estimation method.
- `nuisances/conditional_outcome.py` and every other file in the nuisances folder host the nuisance-parameter computation methods. TBD: some files may be removed because they relate not to `med_bench` but to @houssamzenati's experiments?
- `nuisances/utils.py`: a file containing some utility functions for nuisance-parameter computation. TBD: we could move the content of this file to the global utils folder?
- `tests/estimation/base.py`: I suggest we create a base file to design a testing structure for each estimator; this could include a framework for exactness tests and estimation tests that matches the modularized structure.
- `tests/estimation/test_*.py`: a test file for each estimation method, including exactness and estimation tests.
- `tests/estimation/generate_tests_results.py`: we keep the same logic for exactness tests, but the new data generation must match the format requirements of the new test design.
- `get_estimation.py`: a file for launching estimations; it enables the user to specify every parameter directly from the command line.
- `get_simulated_data.py`: the file for data generation; it should be pretty similar to the existing one.
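To fix ideas, here is a minimal sketch of what the parent class in `estimation/base.py` could look like. The class name `Estimator` and the `fit`/`estimate` signatures are assumptions borrowed from the API discussion further down, not a settled design:

```python
# Hypothetical sketch of estimation/base.py
from abc import ABC, abstractmethod


class Estimator(ABC):
    """Parent class for all mediation effect estimators."""

    def __init__(self, mediator_type="auto"):
        # Concrete, shared constructor logic lives here.
        self.mediator_type = mediator_type

    @abstractmethod
    def fit(self, X, t, m, y):
        """Fit the nuisance models on training data; returns self."""

    @abstractmethod
    def estimate(self, X, t, m, y):
        """Return the estimated mediation effects on the given data."""
```

Each concrete estimator (`dml.py`, `g_computation.py`, ...) would subclass `Estimator` and implement these two methods.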
Regarding the implementation, I suggest we work on a `develop` branch to ensure that everything works for each new feature, before deploying to the `main` branch once everything is implemented.

These overall suggestions are open for discussion, and I would be happy to have your inputs on this.
FYI: @judithabk6 @bthirion @houssamzenati
Thx for the suggestions. I think that this clarifies the design. A few comments:

`get_estimation.py`: I'd rather have all estimation-related stuff in the estimation sub-module, and would import the function at the module level:

```python
from med_bench.estimation import get_estimation
```

but we could also make `get_estimation` importable at the upper level if you wish.
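For reference, both import paths can be supported with one-line re-exports; a sketch, assuming `get_estimation.py` moves into the estimation sub-module:

```python
# med_bench/estimation/__init__.py (hypothetical)
from .get_estimation import get_estimation

# optionally, in med_bench/__init__.py, to also allow upper-level imports:
# from .estimation import get_estimation
```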
Thanks for the great push @brash6, here are a few remarks on my side.

As for the specifics of the submodules:
Implementation remarks: I think your suggestions on the branches are great. Thanks for the propositions.
We took some time last Thursday to discuss the signature of the estimators, which should be the starting point for thinking about the refactor. Here is a beginning; we did not reach a consensus, but we can pick the discussion back up from this and iterate:
we would have a wrapper that then calls the right function:

```python
mediated_effects(method="tmle",
                 mediator_type="binary",
                 nuisance_mu=SklearnRegressor(),
                 nuisance_prop=SklearnClassifier(),
                 nuisance_prop_mediator=SklearnClassifier(),
                 nuisance_density_mediator=kme())
```
with nuisances:

- cross_conditional_mean_outcome: $E[Y(t, M(t'))] = E[E[Y \mid T=t, M=m, X=x] \mid T=t']$
- propensity: $P(T=1 \mid X=x)$
- propensity_mediator: $P(T=1 \mid X=x, M=m)$
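As background (standard mediation analysis, not something specific to this thread), these nuisances are exactly what is needed to decompose the total effect into natural direct and indirect effects:

$$
E[Y(1, M(1))] - E[Y(0, M(0))] = \underbrace{E[Y(1, M(1))] - E[Y(1, M(0))]}_{\text{natural indirect effect}} + \underbrace{E[Y(1, M(0))] - E[Y(0, M(0))]}_{\text{natural direct effect}}
$$

The cross-world term $E[Y(1, M(0))]$ is the quantity the cross_conditional_mean_outcome nuisance targets, and the two propensities are used to reweight (IPW) or correct (DML, TMLE) the plug-in estimates.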
Two open points:

- TBD: how to articulate `mediated_effects` and the `get_estimation` function, to get a more sklearn-like pattern.
- `mediator_type` is an argument, with a default `"auto"` that we do not use in examples (but in experiments, we may want to force treating a binary mediator as continuous); a possible detection rule is sketched after this list.
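For the `"auto"` default, here is a minimal sketch of what the detection could look like; the helper name `infer_mediator_type` and the discreteness threshold are assumptions, nothing here has been decided:

```python
import numpy as np


def infer_mediator_type(m, max_levels=10):
    """Hypothetical rule for mediator_type='auto' (not med_bench API)."""
    m = np.asarray(m)
    if m.ndim == 2 and m.shape[1] > 1:
        return "multidim"
    levels = np.unique(m)
    if len(levels) == 2:
        return "binary"
    # Assumed convention: few integer-valued levels count as discrete.
    if len(levels) <= max_levels and np.allclose(levels, np.round(levels)):
        return "discrete"
    return "continuous"
```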
List of wanted estimators:

| estimator | binary | continuous | multidim | comment |
|---|---|---|---|---|
| CoefficientProduct | x | x | x | linear regression for each dimension of M and for Y (authorize penalisation): maybe one regression object for M and one for Y |
| GComputation | x | x | x | binary/discrete: classifier for M (multiclass if discrete), regression for Y; continuous/multidim: implicit cross-conditioning like `_estimate_cross_conditional_mean_outcome_nesting`, 1 classifier and 1 regressor fitted separately for T=1 and T=0 |
| IPW | x | x | x | `_estimate_treatment_probabilities` to split in 2, for P(T=1\|X) and P(T=1\|X, M); 1 classifier, instantiated twice |
| MultiplyRobust | x | x | x | binary/discrete: classifier for M (multiclass if discrete), `_estimate_cross_conditional_mean_outcome` for Y; continuous/multidim: implicit cross-conditioning like `_estimate_cross_conditional_mean_outcome_nesting`, 1 classifier and 1 regressor fitted separately for T=1 and T=0 |
| DML | x | x | x | binary/discrete: classifier for M (multiclass if discrete), `_estimate_cross_conditional_mean_outcome` for Y; continuous/multidim: implicit cross-conditioning like `_estimate_cross_conditional_mean_outcome_nesting`, 1 classifier and 1 regressor fitted separately for T=1 and T=0 |
| TMLE | x | x | x | default ratio of propensity scores, treatment probabilities, `_estimate_cross_conditional_mean_outcome_nesting` |
A call would then look like:

```python
train_data, test_data = split_train_test(data)

MultiplyRobust(mediator_type="binary",
               nuisance_mu=SklearnRegressor(),
               nuisance_prop=SklearnClassifier(),
               nuisance_prop_mediator=SklearnClassifier())\
    .fit(train_data.X, train_data.t, train_data.m, train_data.y)\
    .estimate(test_data.X, test_data.t, test_data.m, test_data.y)
```
This pattern does not work with DML with cross-fitting. Three options:

1. disable cross-fitting and keep the same pattern:

   ```python
   DML(...., cross_fitting=False)\
       .fit(train_data.X, train_data.t, train_data.m, train_data.y)\
       .estimate(test_data.X, test_data.t, test_data.m, test_data.y)
   ```

2. add a dedicated method:

   ```python
   DML(...., cross_fitting=True).fit_estimate(data)
   ```

3. or a method to implement cross-fitting, which would also work with other estimators:

   ```python
   result = []
   cf_iterator = KFold(k=5)  # pseudocode: yields (data_train, data_test) splits
   for data_train, data_test in cf_iterator:
       result.append(DML(...., cross_fitting=False)
                     .fit(data_train.X, data_train.t, data_train.m, data_train.y)
                     .estimate(data_test.X, data_test.t, data_test.m, data_test.y))
   np.mean(result)
   ```
`data` would be a namedtuple or a dictionary.
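With scikit-learn's actual `KFold` API (which yields index arrays rather than data splits), the third option could be written as the following sketch; the `Dataset` container and the convention that `fit` returns `self` are assumptions carried over from the discussion above:

```python
from collections import namedtuple

import numpy as np
from sklearn.model_selection import KFold

# Assumed data container (could equally be a dictionary).
Dataset = namedtuple("Dataset", ["X", "t", "m", "y"])


def manual_cross_fit(make_estimator, data, n_splits=5):
    """Average effect estimates over K train/test splits."""
    results = []
    for train_idx, test_idx in KFold(n_splits=n_splits).split(data.X):
        train = Dataset(*(a[train_idx] for a in data))
        test = Dataset(*(a[test_idx] for a in data))
        est = make_estimator()  # e.g. lambda: DML(..., cross_fitting=False)
        results.append(est.fit(train.X, train.t, train.m, train.y)
                          .estimate(test.X, test.t, test.m, test.y))
    return np.mean(results, axis=0)
```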
@bthirion do you remember what we chose to do for cross-fitting and model selection?
For now, maybe let's go for the third solution, and we will add a `fit_estimate` method later if we need it? In that case, we should raise a `NotImplementedError` and explain that, to do something similar to cross-fitting, the user should do:
```python
result = []
cf_iterator = KFold(k=5)  # pseudocode: yields (data_train, data_test) splits
for data_train, data_test in cf_iterator:
    result.append(DML(...., cross_fitting=False)
                  .fit(data_train.X, data_train.t, data_train.m, data_train.y)
                  .estimate(data_test.X, data_test.t, data_test.m, data_test.y))
np.mean(result)
```
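A minimal sketch of that guard (the method name and message are illustrative only, not a settled design):

```python
class DML:  # sketch; would inherit from the Estimator base class above
    def fit_estimate(self, data):
        # Hypothetical guard until built-in cross-fitting is supported.
        raise NotImplementedError(
            "Built-in cross-fitting is not supported yet. Loop over KFold "
            "splits, call .fit(...).estimate(...) on each split, and average "
            "the results (see the example above)."
        )
```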
@brash6: To clarify, you mean that we do not support built-in cross-fitting for now, because there is not enough evidence that it's needed? That sounds reasonable. We need to adjust our ambitions to the time we have.
Yes, and we should probably add it back rather soon, but right now I am struggling to get a global view of everything that is going on.
The goal of this issue is to describe the future global refactoring of the package. The aim of this refactoring is to: