Quantco / metalearners

MetaLearners for CATE estimation
https://metalearners.readthedocs.io/en/latest/
BSD 3-Clause "New" or "Revised" License

X-Learner: Use the same sample splits in all base models. #84

Open kklein opened 3 months ago

kklein commented 3 months ago

TODOs:

Observations

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_validate

class Memorizer(BaseEstimator):
    def fit(self, X, y):
        self._y = y
        print(len(y))
        return self

    def score(self, X, y):
        return 0

n_samples = 100
n_folds = 4
# We define cvs such that the union of the training and test sets of every split
# is a proper subset of the dataset (X, y).
cvs = [
    (np.array([fold_index]), np.array([fold_index + 50]))
    for fold_index in range(n_folds)
]
estimator = Memorizer()

X = np.random.normal(size=(n_samples, 2))
y = np.random.normal(size=n_samples)
cross_validate(
    estimator,
    X,
    y,
    cv=cvs,
)

yields the following output:

1
1
1
1
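The output above suggests that cross_validate only ever trains on the indices it is handed, even when the splits do not cover the whole dataset. One way this could be exploited for shared splits is to draw the splits once on the full dataset and then restrict them per treatment variant. The following is a minimal sketch under assumptions, not the implementation in metalearners: the binary treatment vector w, the synthetic data and the helper restrict_splits are all hypothetical.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_validate

rng = np.random.default_rng(0)
n_samples = 100
n_folds = 4
X = rng.normal(size=(n_samples, 2))
y = rng.normal(size=n_samples)
# Hypothetical binary treatment assignment.
w = rng.integers(0, 2, size=n_samples)

# Draw the splits once, on the whole dataset, and materialize them as a list
# so that they can be reused for several models.
global_splits = list(
    KFold(n_splits=n_folds, shuffle=True, random_state=0).split(X)
)


def restrict_splits(splits, mask):
    """Hypothetical helper: keep only the indices for which mask is True."""
    return [(train[mask[train]], test[mask[test]]) for train, test in splits]


treated_splits = restrict_splits(global_splits, w == 1)
control_splits = restrict_splits(global_splits, w == 0)

# Each variant-specific model is cross-fitted on a subset of the same folds.
treated_cv = cross_validate(
    LinearRegression(), X, y, cv=treated_splits, return_estimator=True
)
control_cv = cross_validate(
    LinearRegression(), X, y, cv=control_splits, return_estimator=True
)
```

Since both restricted split lists derive from global_splits, an observation that lies in the test set of a global fold is never part of the training set of that fold for either variant model.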

Checklist

kklein commented 3 months ago

FYI @MatthiasLoefflerQC: I created a first draft of how the same splits could be used for all base learners, including the treatment models. As of now the estimates are still clearly awry, e.g. an RMSE of ~13 compared to ~0.05. This happens for both in-sample and out-of-sample estimation. I currently have no real idea what's going wrong, but I will keep trying to make progress.

kklein commented 3 months ago

As of now the estimates are still clearly awry, e.g. an RMSE of ~13 compared to ~0.05.

The base models all seem to be doing fine with respect to their individual targets. Yet, when I compare pairs of treatment effect model estimates at prediction time, it becomes blatantly apparent that something is going wrong:

>>> np.mean(tau_hat_control - tau_hat_treatment)
27.051119307766754
>>> np.mean(tau_hat_control)
14.104902455634836
>>> np.mean(tau_hat_treatment)
-12.946216852131919

Update: These discrepancies have been substantially reduced by bbfff15. The RMSEs on the true CATEs are still massive compared to the status quo.
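
To illustrate why a gap between the two treatment effect models matters: in the usual X-Learner formulation, the final CATE estimate is a propensity-weighted combination of the two models, so opposite biases do not simply cancel out. The sketch below uses purely synthetic numbers, not the values from the run above. The function combine_x_learner, the constant true CATE and the size of the shift are assumptions made for illustration.

```python
import numpy as np


def combine_x_learner(tau_hat_control, tau_hat_treatment, propensity):
    """Propensity-weighted combination as in the usual X-Learner formulation."""
    return propensity * tau_hat_control + (1 - propensity) * tau_hat_treatment


def rmse(estimate, truth):
    return float(np.sqrt(np.mean((estimate - truth) ** 2)))


rng = np.random.default_rng(0)
n = 1000
true_cate = np.full(n, 1.0)  # synthetic, constant treatment effect
propensity = rng.uniform(0.2, 0.8, size=n)  # synthetic propensity scores

# Case 1: both treatment effect models are close to the truth.
tau_hat_control = true_cate + rng.normal(scale=0.05, size=n)
tau_hat_treatment = true_cate + rng.normal(scale=0.05, size=n)
rmse_agree = rmse(
    combine_x_learner(tau_hat_control, tau_hat_treatment, propensity), true_cate
)

# Case 2: the two models are shifted in opposite directions.
shift = 13.0  # hypothetical magnitude
rmse_disagree = rmse(
    combine_x_learner(
        tau_hat_control + shift, tau_hat_treatment - shift, propensity
    ),
    true_cate,
)

print(rmse_agree, rmse_disagree)
```

The opposite shifts only cancel where the propensity score equals 0.5. Everywhere else a share of the shift survives in the combined estimate, which is one way a disagreement between the two models can show up as a large RMSE on the true CATEs.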