soda-inria / hazardous

Competing Risks and Survival Analysis
https://soda-inria.github.io/hazardous/
MIT License

Add the c-index with IPCW #71

Open Vincent-Maladiere opened 1 month ago

Vincent-Maladiere commented 1 month ago

What does this PR propose?

This PR proposes to add the c-index as defined in [1]. I think this is ready to be reviewed for merging, with some questions/suggestions in the TODO section below.

show maths: [four screenshots: the c-index estimator and the definitions of its terms], where $M$ is the probability of incidence of the event of interest.
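
In text form, a plausible reconstruction of the estimator from [1], consistent with the pair weights $W_{ij,1}$ and $W_{ij,2}$ discussed further down in this thread (the screenshots may use slightly different tie-handling and truncation conventions):

$$\hat{C}_k(\tau) = \frac{\sum_{i,j} \left( A_{ij}\, W_{ij,1}^{-1} + B_{ij}\, W_{ij,2}^{-1} \right) \mathbb{1}\{\hat{M}_k(\tau \mid X_i) > \hat{M}_k(\tau \mid X_j)\}}{\sum_{i,j} \left( A_{ij}\, W_{ij,1}^{-1} + B_{ij}\, W_{ij,2}^{-1} \right)}$$

where $A_{ij} = \mathbb{1}\{T_i \le \tau,\ \Delta_i = k,\ T_i < T_j\}$ marks pairs in which $i$ fails from the event of interest $k$ before $j$'s observed time, $B_{ij} = \mathbb{1}\{T_i \le \tau,\ \Delta_i = k,\ T_j \le T_i,\ \Delta_j \notin \{0, k\}\}$ marks pairs in which $j$ fails first from a competing event, $\hat{G}$ is the censoring survival function entering the weights, and $\hat{M}_k$ is the predicted cumulative incidence of the event of interest.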

TODO

cc @ogrisel @GaelVaroquaux @juAlberge @glemaitre

[1] Wolbers, M., Blanche, P., Koller, M. T., Witteman, J. C., & Gerds, T. A. (2014). Concordance for prognostic models with competing risks. Biostatistics, 15(3), 526–539.

Vincent-Maladiere commented 1 month ago

The CI for the doc fails because the previous boosting tree model is missing. This should be fixed when https://github.com/soda-inria/hazardous/pull/53 is merged.

Vincent-Maladiere commented 1 month ago

Update on performance

Our implementation is 100x slower than scikit-survival's concordance_index_ipcw. This is due to computing the weights (the IPCWs) inside the BalancedTree, a step that lifelines doesn't perform.

code benchmark

```python
import numpy as np
import pandas as pd
from time import time

from lifelines import CoxPHFitter
from lifelines.datasets import load_kidney_transplant
from sklearn.model_selection import train_test_split

from hazardous.metrics._concordance_index import _concordance_index_incidence_report

df = load_kidney_transplant()

# make the dataset 100x longer for benchmarking purposes
df = pd.concat([df] * 100, axis=0)

df_train, df_test = train_test_split(df, stratify=df["death"])
cox = CoxPHFitter().fit(df_train, duration_col="time", event_col="death")

t_min, t_max = df["time"].min(), df["time"].max()
time_grid = np.linspace(t_min, t_max, 20)

# predicted incidence = 1 - survival, evaluated on the time grid
y_pred = 1 - cox.predict_survival_function(df_test, times=time_grid).T.to_numpy()

y_train = df_train[["death", "time"]].rename(
    columns=dict(death="event", time="duration")
)
y_test = df_test[["death", "time"]].rename(
    columns=dict(death="event", time="duration")
)

# our implementation
tic = time()
result = _concordance_index_incidence_report(
    y_test=y_test,
    y_pred=y_pred,
    time_grid=time_grid,
    taus=None,
    y_train=y_train,
)
print(f"our implementation: {time() - tic:.2f}s")

# scikit-survival
from sksurv.metrics import concordance_index_ipcw

def make_recarray(y):
    event, duration = y["event"].values, y["duration"].values
    return np.array(
        [(event[i], duration[i]) for i in range(len(event))],
        dtype=[("e", bool), ("t", float)],
    )

tic = time()
concordance_index_ipcw(
    make_recarray(y_train),
    make_recarray(y_test),
    y_pred[:, -1],
    tau=None,
)
print(f"scikit-survival: {time() - tic:.2f}s")

# lifelines
from lifelines.utils import concordance_index

tic = time()  # reset the timer (the original snippet reused the previous tic)
concordance_index(
    event_times=y_test["duration"],
    predicted_scores=1 - y_pred[:, -1],
    event_observed=y_test["event"],
)
print(f"lifelines: {time() - tic:.2f}s")
```

On a dataset with 20k rows:

our implementation: 18.10s
scikit-survival: 0.24s
lifelines: 0.27s

The flamegraph makes the culprit quite clear: the list comprehension that computes the IPCW weight for each pair. When I remove the IPCWs, the performance becomes similar to lifelines.

Speedscope views of our implementation [two profiler screenshots]

I tried to fix this performance issue by applying numba's @jitclass to the BTree, but it is still very slow. I put the numba BTree on a separate draft branch for reference.
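
For context, a minimal example of the numba @jitclass pattern mentioned above, using a generic Fenwick tree for weighted prefix counting rather than the actual BTree from the draft branch:

```python
import numpy as np
from numba import float64, int64
from numba.experimental import jitclass

# minimal @jitclass example: a Fenwick tree for weighted prefix counting.
# This is NOT the BTree from the draft branch, just the compilation pattern.
spec = [
    ("tree", float64[:]),
    ("size", int64),
]

@jitclass(spec)
class FenwickTree:
    def __init__(self, n):
        self.size = n
        self.tree = np.zeros(n + 1, dtype=np.float64)

    def add(self, i, w):
        # add weight w at 0-based position i
        i += 1
        while i <= self.size:
            self.tree[i] += w
            i += i & (-i)

    def prefix_sum(self, i):
        # total weight at positions 0..i
        i += 1
        total = 0.0
        while i > 0:
            total += self.tree[i]
            i -= i & (-i)
        return total

tree = FenwickTree(8)
tree.add(3, 0.5)
tree.add(5, 1.0)
print(tree.prefix_sum(4))  # 0.5
```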

Conclusion

I only see two ways forward:

  1. Either my computation of the IPCW in the BTree is flawed, and we can fix the performance issue;
  2. or the BTree is not suited to our metric, and we have to fall back on a non-optimized pairwise implementation like scikit-survival's, with $O(n^2)$ instead of $O(n \log n)$ time complexity (see the sketch below). This would simplify the code base, though.
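
For reference, a minimal sketch of what the naive pairwise version (option 2) could look like, shown here for the non-conditional (Kaplan-Meier) weights; the names and signature are hypothetical, not the PR's API:

```python
import numpy as np

def naive_cindex_ipcw(event, duration, risk, g_hat, event_of_interest=1):
    """Hypothetical naive O(n^2) cause-specific c-index with IPCW (sketch).

    event: 0 for censored, k >= 1 for event type k
    duration: observed times T_i
    risk: predicted incidence of the event of interest (higher = riskier)
    g_hat: censoring survival estimates G(T_i), e.g. from Kaplan-Meier
    Time horizon tau and tie handling are omitted for brevity.
    """
    num = den = 0.0
    n = len(duration)
    for i in range(n):
        if event[i] != event_of_interest:
            continue  # pairs are indexed by subjects failing from the event of interest
        for j in range(n):
            if j == i:
                continue
            if duration[i] < duration[j]:
                w = 1.0 / g_hat[i] ** 2          # "type A" pair: weight 1 / G(T_i)^2
            elif event[j] not in (0, event_of_interest):
                w = 1.0 / (g_hat[i] * g_hat[j])  # "type B": j had a competing event first
            else:
                continue  # pair not comparable
            den += w
            num += w * float(risk[i] > risk[j])
    return num / den
```
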
jjerphan commented 1 month ago

Pinged by @Vincent-Maladiere, but have no time for it.

Random pile of pieces of advice:

  • find if a better algorithm exists first
  • profile to see what's the bottleneck
  • see if tree-based structures can be used from another library (e.g. pydatastructures)
  • use another language (like Cython or C++) to implement the costly algorithmic part

GaelVaroquaux commented 1 month ago

No, don't use compiled languages, please. It will make release and distribution much harder.

Vincent-Maladiere commented 1 month ago

After giving it some more thought, there is room for improvement in the current balanced tree design:

  1. When we don't use an IPCW estimator (like lifelines): $$W_{ij,1} = W_{ij,2} = 1$$
  2. When we use a non-conditional IPCW estimator (Kaplan-Meier, like scikit-survival): $$W_{ij,1} = W_{i,1} = \hat{G}(T_i)^2 \quad \text{and} \quad W_{ij,2} = \hat{G}(T_i)\,\hat{G}(T_j)$$
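
In both cases, the pair weight factorizes over individuals, which is what makes the balanced-tree accumulation work: for a fixed $i$, the per-$i$ factor comes out of the inner sum, and what remains is an unweighted concordant-pair count that the tree can answer in $O(\log n)$. Sketching this for case 2:

$$\sum_j W_{ij,1}^{-1}\, \mathbb{1}\{T_i < T_j\}\, \mathbb{1}\{\hat{M}_i > \hat{M}_j\} = \hat{G}(T_i)^{-2} \sum_{j:\, T_i < T_j} \mathbb{1}\{\hat{M}_i > \hat{M}_j\}$$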

However, when we use a conditional IPCW estimator (like Cox or SurvivalBoost), we have: $$W_{ij,1} = \hat{G}(T_i \mid X_i)\,\hat{G}(T_i \mid X_j) \quad \text{and} \quad W_{ij,2} = \hat{G}(T_i \mid X_i)\,\hat{G}(T_j \mid X_j)$$

In this case, the balanced tree is no longer suitable (the factor $\hat{G}(T_i \mid X_j)$ evaluates the censoring model at time $T_i$ under the covariates of $j$, so $W_{ij,1}$ no longer factorizes into a per-$i$ and a per-$j$ term), and we should use the naive implementation.

So, to make things simpler, I suggest we only implement the naive version for now, and return to the balanced tree later for the non-conditional and unweighted cases.

WDYT?

GaelVaroquaux commented 1 month ago

Sounds good to me. We can always iterate if needed.

Vincent-Maladiere commented 2 weeks ago

Here is the revised version. When used on a survival dataset, it gives identical results to scikit-survival, with a slightly better time complexity.

[benchmark plot: cindex_duration]
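
For the survival case, the agreement can be spot-checked by reusing the objects from the benchmark script posted earlier in this thread (a hypothetical check; the exact structure of the returned report isn't shown here):

```python
# reuses y_train, y_test, y_pred, make_recarray and result from the
# benchmark script above
sksurv_cindex = concordance_index_ipcw(
    make_recarray(y_train),
    make_recarray(y_test),
    y_pred[:, -1],
    tau=None,
)[0]  # first element of the returned tuple is the c-index estimate

print(result)         # report from our implementation
print(sksurv_cindex)  # scikit-survival's estimate; expected to match
```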

Vincent-Maladiere commented 1 week ago

This PR is now ready to be reviewed :)