sebp / scikit-survival

Survival analysis built on top of scikit-learn
GNU General Public License v3.0

SurvivalTree is handling sample_weight incorrectly #443

Closed sebp closed 3 months ago

sebp commented 6 months ago

Describe the bug

Sample weights passed via sample_weight to SurvivalTree.fit() are not taken into account correctly.

This is essential for RandomSurvivalForest to work correctly, because the bootstrap sample for each tree in the ensemble is created by passing sample_weight to SurvivalTree.fit(). For instance, sample_weight=[1, 0, 2, 1] would represent a bootstrap dataset in which the first and last samples each appear once, the second sample is left out, and the third sample appears twice.
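The correspondence between a bootstrap draw and a weight vector can be sketched in plain Python. This is only an illustration of the encoding described above, not the code used by scikit-learn or scikit-survival; the helper names `bootstrap_weights` and `expand` are hypothetical.

```python
import random
from collections import Counter


def bootstrap_weights(n_samples, seed=0):
    """Draw a bootstrap sample and encode it as per-sample integer weights."""
    rng = random.Random(seed)
    # Draw n_samples indices with replacement.
    drawn = [rng.randrange(n_samples) for _ in range(n_samples)]
    counts = Counter(drawn)
    # weight[i] is the number of times sample i was drawn (0 if never drawn).
    return [counts.get(i, 0) for i in range(n_samples)]


def expand(samples, weights):
    """Materialize the dataset a weight vector describes by repeating samples."""
    return [s for s, w in zip(samples, weights) for _ in range(w)]


samples = ["a", "b", "c", "d"]
expanded = expand(samples, [1, 0, 2, 1])
print(expanded)  # ['a', 'c', 'c', 'd']

weights = bootstrap_weights(len(samples), seed=0)
print(weights, sum(weights))  # the weights always sum to n_samples
```

A tree that handles sample_weight correctly should produce the same result whether it receives the weight vector or the expanded dataset.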

Code Sample to Reproduce the Bug

```python
import numpy as np

from sksurv.datasets import load_whas500
from sksurv.preprocessing import OneHotEncoder
from sksurv.tree import SurvivalTree

X, y = load_whas500()
Xt = OneHotEncoder().fit_transform(X)
n_samples = Xt.shape[0]

weights = np.ones(n_samples, dtype=int)
weights[:11] = np.arange(11, dtype=int)

y_array = np.empty((Xt.shape[0], 2), dtype=float)
y_array[:, 0] = y["lenfol"]
y_array[:, 1] = y["fstat"].astype(float)

X_repeat = np.repeat(Xt, weights, axis=0)
y_repeat = np.repeat(y_array, weights, axis=0)

# fit on the full data to create unique_times_ and is_event_time_
t0 = SurvivalTree(random_state=2).fit(Xt, y)

# fit on dataset where samples have been copied
t1 = SurvivalTree(random_state=2)._fit(
    X_repeat, (y_repeat, t0.unique_times_, t0.is_event_time_), check_input=False
)

# fit on dataset where sample_weight is used
t2 = SurvivalTree(random_state=2)._fit(
    Xt.values,
    (y_array, t0.unique_times_, t0.is_event_time_),
    sample_weight=weights.astype(float),
    check_input=False,
)

value_1 = t1.tree_.value
value_2 = t2.tree_.value
# check that both trees are identical
assert np.allclose(value_1, value_2)
```

The example fails, whereas the same comparison with a DecisionTreeClassifier does yield identical trees.

```python
from sklearn.tree import DecisionTreeClassifier

t1 = DecisionTreeClassifier(random_state=2).fit(
    X_repeat, y_repeat[:, 1]
)
t2 = DecisionTreeClassifier(random_state=2).fit(
    Xt, y["fstat"], sample_weight=weights
)

value_1 = t1.tree_.value
value_2 = t2.tree_.value
print(np.allclose(value_1, value_2))
```

Expected Results

The trees t1 and t2 should be identical.
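The expected equivalence can be illustrated on a much simpler statistic than a full tree. The sketch below is a stand-in, not SurvivalTree's implementation: it computes a Nelson-Aalen cumulative hazard with optional weights and checks that integer weights give the same estimate as repeating the samples. The function `nelson_aalen` and the toy data are made up for this illustration.

```python
import math


def nelson_aalen(times, events, weights=None):
    """Cumulative hazard at each distinct event time, with optional sample weights."""
    if weights is None:
        weights = [1.0] * len(times)
    # Only event times carried by samples with positive weight contribute.
    event_times = sorted(
        {t for t, e, w in zip(times, events, weights) if e and w > 0}
    )
    cum_hazard, result = 0.0, []
    for t in event_times:
        # Weighted number of samples still at risk at time t.
        at_risk = sum(w for ti, w in zip(times, weights) if ti >= t)
        # Weighted number of events occurring exactly at time t.
        n_events = sum(
            w for ti, e, w in zip(times, events, weights) if e and ti == t
        )
        cum_hazard += n_events / at_risk
        result.append((t, cum_hazard))
    return result


times = [5.0, 8.0, 8.0, 12.0]
events = [True, False, True, True]
weights = [1, 0, 2, 1]

# Repeat every sample according to its weight.
times_rep = [t for t, w in zip(times, weights) for _ in range(w)]
events_rep = [e for e, w in zip(events, weights) for _ in range(w)]

weighted = nelson_aalen(times, events, weights)
repeated = nelson_aalen(times_rep, events_rep)

same = len(weighted) == len(repeated) and all(
    a[0] == b[0] and math.isclose(a[1], b[1]) for a, b in zip(weighted, repeated)
)
print(same)  # True
```

Any per-node quantity built from weighted at-risk and event counts in this way would agree between the two datasets, which is the behavior the reproduction script tests for.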

Actual Results

ValueError: operands could not be broadcast together with shapes (195,395,2) (187,395,2) 

Versions

Latest version from the master branch.