DataCanvasIO / YLearn

YLearn, a pun of "learn why", is a Python package for causal inference
https://ylearn.readthedocs.io
Apache License 2.0
391 stars, 75 forks

plot_causal_tree in causal_tree.py fails with "'CausalTree' object has no attribute 'criterion'" #42

Closed: guangwufengqi closed this issue 1 year ago

guangwufengqi commented 1 year ago

Please make sure that this is a bug: Yes

System information

Describe the current behavior

I ran the code below and fitted a causal tree successfully, but when I try to plot the tree with plot_causal_tree, it raises an error: 'CausalTree' object has no attribute 'criterion'

Describe the expected behavior

Standalone code to reproduce the issue

```python
import numpy as np
import matplotlib.pyplot as plt

from ylearn.estimator_model.causal_tree import CausalTree  # the estimator used here is a causal tree
from ylearn.exp_dataset.exp_data import sq_data
from ylearn.utils._common import to_df  # convenient DataFrame conversion helper provided by ylearn

# Data preparation
n = 2000
d = 10
n_x = 1
y, x, v = sq_data(n, d, n_x)  # y: 2000x1, x: 2000x1, v: 2000x10

# True treatment effect
true_te = lambda X: np.hstack([X[:, [0]]**2 + 1, np.ones((X.shape[0], n_x - 1))])
data = to_df(treatment=x, outcome=y, v=v)  # the final dataset

# Construct test data
v_test = v[:min(100, n)].copy()
# Replace the first column with evenly spaced values from the 1st to the 99th percentile of v[:, 0]
v_test[:, 0] = np.linspace(np.percentile(v[:, 0], 1), np.percentile(v[:, 0], 99), min(100, n))
test_data = to_df(v=v_test)  # 100x10
# (I don't yet understand the purpose of the evenly spaced first column.)

# Build and fit the causal tree
outcome = 'outcome'
treatment = 'treatment'
adjustment = data.columns[2:]
ct = CausalTree(min_samples_leaf=3, max_depth=5)
ct.fit(data=data, outcome=outcome, treatment=treatment, adjustment=adjustment)
ct_pred = ct.estimate(data=test_data)  # predict on the 100-row test data

ct.plot_causal_tree(max_depth=5, feature_names=adjustment)  # this line raises the error; please tell me how to fix it
```

Are you willing to submit a PR? (Yes/No): Yes

Other info / logs

BochenLv commented 1 year ago

Hi guangwufengqi,

Thanks for opening this issue. The bug occurs because CausalTree does not explicitly define criterion as one of its attributes, while the plotting code tries to read it.
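To make the failure mode concrete, here is a minimal sketch using stand-in classes. TreeWithoutCriterion and export_tree_stub are hypothetical names, not ylearn code: the estimator's fit() never assigns criterion, and the exporter reads it, so the lookup fails with the same kind of AttributeError.

```python
class TreeWithoutCriterion:
    """Stand-in for an estimator whose fit() never assigns `criterion`."""

    def __init__(self, max_depth=None):
        self.max_depth = max_depth

    def fit(self):
        # Fitting records other state but never sets self.criterion.
        self.is_fitted_ = True
        return self


def export_tree_stub(tree):
    """Stand-in for a tree exporter that reads `tree.criterion` for node labels."""
    return f"criterion = {tree.criterion}"


tree = TreeWithoutCriterion(max_depth=5).fit()
try:
    export_tree_stub(tree)
except AttributeError as err:
    message = str(err)
    print(message)  # 'TreeWithoutCriterion' object has no attribute 'criterion'
```

Assigning the attribute on the instance (tree.criterion = "HonestMSE") is enough for the stub exporter to succeed, which is the idea behind the workaround.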

As a workaround, add the following line before calling plot_causal_tree on your CausalTree instance:

ct.criterion = "HonestMSE" 
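If the same script may later run against a ylearn release that already defines criterion, a guarded version of this workaround avoids overwriting the real value. This is a sketch: ensure_criterion is a hypothetical helper, and SimpleNamespace objects stand in for a fitted CausalTree so the example runs without ylearn.

```python
from types import SimpleNamespace


def ensure_criterion(estimator, default="HonestMSE"):
    """Set `estimator.criterion` only if it is missing; return the estimator.

    Hypothetical helper: an existing value is left untouched, so the call
    stays harmless once the library sets the attribute itself.
    """
    if not hasattr(estimator, "criterion"):
        estimator.criterion = default
    return estimator


# Stand-in objects instead of a fitted CausalTree, so this runs without ylearn.
patched = ensure_criterion(SimpleNamespace())
already_set = ensure_criterion(SimpleNamespace(criterion="custom"))
print(patched.criterion)      # HonestMSE
print(already_set.criterion)  # custom

# With ylearn installed, the intended use would be:
#   ensure_criterion(ct)
#   ct.plot_causal_tree(max_depth=5, feature_names=adjustment)
```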

We will fix this bug in the next release, which will also include several forest-based estimator models. We recommend trying those models once they are available, as they should give better performance.

guangwufengqi commented 1 year ago

Thanks for your reply, this is useful! I also found another way: I added one line, self.criterion = criterion, to the package file causal_tree.py, at the end of the _fit_with_array method. That also worked.

BochenLv commented 1 year ago

Yes, this will also work 😄
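Editing causal_tree.py inside the installed package is lost on reinstall or upgrade. One alternative is a subclass that sets the attribute after fit(). The sketch below uses a stand-in base class (StandInTree, hypothetical) so it runs without ylearn; with ylearn you would subclass CausalTree instead. It assigns the maintainer's string "HonestMSE" rather than the internal criterion object that the one-line package edit stores, and it simply passes through whatever the parent fit() returns.

```python
class StandInTree:
    """Stand-in for ylearn's CausalTree: fit() never sets `criterion`."""

    def fit(self, data, outcome, treatment, **kwargs):
        self.is_fitted_ = True
        return self


class PlottableTree(StandInTree):
    """Hypothetical subclass that adds `criterion` after fitting.

    With ylearn, the base class would be CausalTree instead of StandInTree.
    """

    def fit(self, *args, **kwargs):
        result = super().fit(*args, **kwargs)
        # Same effect as the one-line package edit, but only if the parent did not set it.
        if not hasattr(self, "criterion"):
            self.criterion = "HonestMSE"
        return result


tree = PlottableTree().fit(data=None, outcome="outcome", treatment="treatment")
print(tree.criterion)  # HonestMSE
```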