CamDavidsonPilon / lifelines

Survival analysis in Python
lifelines.readthedocs.org
MIT License

CI Table Not Outputting For All Classes In Column #1525

Closed (sapaca closed 1 year ago)

sapaca commented 1 year ago

For this code:

for name, grouped_df in df.groupby('mar'):
    kmf = KaplanMeierFitter(alpha=alpha)
    kmf.fit(durations=grouped_df['week'], event_observed=grouped_df['arrest'], label=name)

After running the fit above and executing this:

kmf.confidence_interval_

only one class appears in the output table, even though plotting shows the CI estimates (shaded area) for every class in the column of interest. In the output below, I would expect to see 0_lower_0.95 and 0_upper_0.95 columns as well.

      1_lower_0.95  1_upper_0.95
0.0       1.000000      1.000000
12.0      0.873516      0.997320
23.0      0.857428      0.990427
26.0      0.834689      0.981385
39.0      0.811287      0.970985
42.0      0.788081      0.959609
43.0      0.742813      0.934739
46.0      0.720761      0.921486
52.0      0.720761      0.921486
CamDavidsonPilon commented 1 year ago

I suspect it's because you are re-assigning the kmf variable, so the last iteration defines what kmf is (which is a KaplanMeierFitter trained only on data where mar=1).
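
The same behaviour can be reproduced without lifelines at all. Below is a minimal sketch in which FakeFitter is a hypothetical stand-in for KaplanMeierFitter (it is not part of any library); it only shows that a name assigned inside a loop ends up bound to the object from the final iteration:

```python
# Hypothetical stand-in for a fitted model; not part of lifelines.
class FakeFitter:
    def __init__(self, label):
        self.label = label


for name in [0, 1]:
    # Each pass rebinds `fitter`, so the previous object is no longer reachable.
    fitter = FakeFitter(label=name)

# After the loop, only the object created in the last iteration remains.
print(fitter.label)  # prints 1
```

Storing each object in a container before the next iteration is what keeps the earlier ones available.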

I suggest breaking this into something like:

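from lifelines import KaplanMeierFitter

kmfs = {}  # one fitted model per group, keyed by the group's value of 'mar'
for name, grouped_df in df.groupby('mar'):
    kmf = KaplanMeierFitter(alpha=alpha)
    kmf.fit(durations=grouped_df['week'], event_observed=grouped_df['arrest'], label=name)
    kmfs[name] = kmf  # keep this group's fitter before kmf is rebound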

And then you have access to all kmfs:

print(kmfs[0].confidence_interval_)
print(kmfs[1].confidence_interval_)
sapaca commented 1 year ago

Perfect, thank you. That works for me, I appreciate it!