Novartis / torchsurv

Deep survival analysis made easy
http://opensource.nibr.com/torchsurv/
MIT License

Plot Kaplan-Meier Curves Based on Model Predictions Using torchsurv? #63

Open tadongguk opened 1 week ago

tadongguk commented 1 week ago

I am using the torchsurv library for survival analysis and want to plot Kaplan-Meier (KM) curves based on my model's predictions, not the baseline KM curve. I'm looking for guidance on how to generate these KM curves using the predictions from torchsurv. Could anyone provide advice or example code for this?

melodiemonod commented 6 days ago

Hi @tadongguk ,

Many thanks for your question and your interest in our package.

The 'Kaplan-Meier curve' is a plot of the Kaplan-Meier estimates of the survival function. Although our package has functions to calculate the Kaplan-Meier estimates, we do not have functions to plot them. I would refer you to the lifelines package for that (see the code below).

You said that you wanted to plot your model predictions, and by this I understand that you wish to plot the survival function that your model estimates. The Weibull model indeed allows you to estimate the survival function from the trained Weibull parameters. Note that the Cox model on its own does not, because its baseline hazard is left unspecified.
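For reference, under the standard scale-shape parameterization the survival function implied by the Weibull model is

$$S_i(t) = \exp\left(-\left(\frac{t}{\lambda_i}\right)^{\rho_i}\right)$$

where the scale $\lambda_i$ and shape $\rho_i$ are obtained by exponentiating the columns of the model's log-parameters (please check the torchsurv.loss.weibull documentation for the exact column order; the mapping here is my assumption). This is the quantity evaluated on a time grid in the Weibull code further below.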

We do not currently have functions to plot the survival function estimates after fitting the Weibull model, but I will show you below how you could do it. I will use the GBSG2 data that we used in the tutorial and plot the estimated survival function for each of the three tumor grades. Before running the code below, please run Section 2 of the tutorial (there is no need to run Section 1).

Obtain KM estimates and plot them

# Obtain the Kaplan Meier estimate and plot them

from lifelines import KaplanMeierFitter
from matplotlib.pyplot import subplots

fig, ax = subplots(figsize=(8, 8))

# Grade I
km_gradeI = KaplanMeierFitter()
df_train_subset = df_train[
    (df_train["tgrade_II"] == 0.0) & (df_train["tgrade_III"] == 0.0)
]
km_gradeI.fit(df_train_subset["time"], df_train_subset["cens"])
km_gradeI.plot(label="Grade I", ax=ax)

# Grade II
km_gradeII = KaplanMeierFitter()
df_train_subset = df_train[
    (df_train["tgrade_II"] == 1.0) & (df_train["tgrade_III"] == 0.0)
]
km_gradeII.fit(df_train_subset["time"], df_train_subset["cens"])
km_gradeII.plot(label="Grade II", ax=ax)

# Grade III
km_gradeIII = KaplanMeierFitter()
df_train_subset = df_train[
    (df_train["tgrade_II"] == 0.0) & (df_train["tgrade_III"] == 1.0)
]
km_gradeIII.fit(df_train_subset["time"], df_train_subset["cens"])
km_gradeIII.plot(label="Grade III", ax=ax)

[Figure: Kaplan-Meier survival curves for tumor grades I, II and III]
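If you prefer to compute the Kaplan-Meier estimates with torchsurv itself and only use matplotlib for the plotting, a minimal sketch is below. It assumes that the estimator is exposed as KaplanMeierEstimator in torchsurv.stats.kaplan_meier, that it is fitted by calling it with (event, time) tensors, and that predict() returns survival probabilities at new times; please double check the API reference for the exact signatures before relying on it.

# Sketch (API details assumed, see note above): torchsurv KM estimate for Grade I
import torch
from torchsurv.stats.kaplan_meier import KaplanMeierEstimator
from matplotlib import pyplot as plt

df_gradeI = df_train[(df_train["tgrade_II"] == 0.0) & (df_train["tgrade_III"] == 0.0)]

km = KaplanMeierEstimator()
km(
    torch.tensor(df_gradeI["cens"].values, dtype=torch.bool),     # event indicator
    torch.tensor(df_gradeI["time"].values, dtype=torch.float32),  # follow-up time
)

# query the fitted survival probabilities on a grid and plot them as a step function
grid = torch.linspace(0.0, float(df_gradeI["time"].max()), 200)
plt.step(grid.numpy(), km.predict(grid).numpy(), where="post", label="Grade I (torchsurv KM)")
plt.xlabel("Time")
plt.ylabel("Survival Probability")
plt.legend()
plt.show()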

Obtain survival function estimates from Weibull model and plot them

# The imports below are already in scope if you ran Section 2 of the tutorial;
# they are repeated here so this block runs on its own (adjust the import path
# to your torchsurv version if needed).
import torch
import matplotlib.pyplot as plt
from torchsurv.loss.weibull import survival_function

n = df_onehot.shape[0]  # 686 patients

# create Grade I, Grade II and Grade III covariates of the size of the data set
tgradeI = torch.tensor([[0., 0.]], dtype=torch.float32).repeat(n, 1)   # torch.Size([686, 2])
tgradeII = torch.tensor([[1., 0.]], dtype=torch.float32).repeat(n, 1)  # torch.Size([686, 2])
tgradeIII = torch.tensor([[0., 1.]], dtype=torch.float32).repeat(n, 1) # torch.Size([686, 2])
tgrade = torch.cat((tgradeI, tgradeII, tgradeIII), dim=0)

# append Grades to each covariates replica
x_pred = torch.tensor(df_onehot.copy().drop(["cens", "time", "tgrade_II", "tgrade_III"], axis=1).values, dtype=torch.float32)  # torch.Size([686, 9])
x_pred = torch.cat((x_pred.repeat(3, 1), tgrade), dim=1)  # torch.Size([2058, 11])

# weibull parameters for covariates x_pred
weibull_model.eval()
with torch.no_grad():
    log_params = weibull_model(x_pred) # torch.Size([2058, 2])

# time grid
time_step = 10.0
time_grid = torch.arange(0., max(df_onehot['time']) + time_step, time_step) # torch.Size([267])

# Compute the survival probability on time grid
surv_pred = [survival_function(log_params, time=t) for t in time_grid]
surv_pred = torch.stack(surv_pred) # torch.Size([267, 2058])

# split by grade
surv_pred_tgradeI = surv_pred[:,0:n]
surv_pred_tgradeII = surv_pred[:,n:n*2]
surv_pred_tgradeIII = surv_pred[:,n*2:n*3]

# Take the median across patients (you could also take quantiles to show the spread; see the sketch after the plot below)
surv_summary_tgradeI = surv_pred_tgradeI.median(dim=1).values
surv_summary_tgradeII = surv_pred_tgradeII.median(dim=1).values
surv_summary_tgradeIII = surv_pred_tgradeIII.median(dim=1).values

# Create a plot
plt.figure(figsize=(10, 6))

# Plot median survival predictions 
plt.plot(time_grid.numpy(), surv_summary_tgradeI.numpy(), label='Grade 1', color='blue')
plt.plot(time_grid.numpy(), surv_summary_tgradeII.numpy(), label='Grade 2', color='orange')
plt.plot(time_grid.numpy(), surv_summary_tgradeIII.numpy(), label='Grade 3', color='green')

# Optional: Overlay with KM survival estimates
# plt.plot(km_gradeI.timeline, km_gradeI.survival_function_.values, label='Grade 1', color='blue', linestyle='--')
# plt.plot(km_gradeII.timeline, km_gradeII.survival_function_.values, label='Grade 2', color='orange', linestyle='--')
# plt.plot(km_gradeIII.timeline, km_gradeIII.survival_function_.values, label='Grade 3', color='green', linestyle='--')

# Add labels and title
plt.xlabel('Time')
plt.ylabel('Survival Probability')
plt.title('Survival Function')
plt.legend()
plt.grid()
plt.show()

[Figure: median predicted Weibull survival functions by tumor grade]
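As noted in the code comment above, you can also summarize the spread of the per-patient survival curves instead of only their median. A minimal sketch for Grade I using pointwise 10%/90% quantiles across patients (the quantile levels are an arbitrary choice, and this shows the spread across patients rather than a formal confidence interval):

# Pointwise quantile band around the Grade I median curve, reusing surv_pred_tgradeI from above
q_lo = torch.quantile(surv_pred_tgradeI, 0.1, dim=1)
q_hi = torch.quantile(surv_pred_tgradeI, 0.9, dim=1)

plt.figure(figsize=(10, 6))
plt.plot(time_grid.numpy(), surv_summary_tgradeI.numpy(), label='Grade 1 (median)', color='blue')
plt.fill_between(time_grid.numpy(), q_lo.numpy(), q_hi.numpy(), color='blue', alpha=0.2,
                 label='Grade 1 (10%-90% across patients)')
plt.xlabel('Time')
plt.ylabel('Survival Probability')
plt.legend()
plt.grid()
plt.show()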

tadongguk commented 3 days ago

@melodiemonod

Thank you for your response. I see that when using neg_partial_log_likelihood as the loss, the log_hz value is obtained from cox_model(x). Is log_hz the risk score here?

If I want to estimate the survival function, do I only need to add a function to calculate the cumulative baseline hazard using the Breslow method?

melodiemonod commented 3 days ago

Is log_hz the risk score here?

This is documented on the TorchSurv website. Quoting the documentation:

The log hazard function for the Cox proportional hazards model has the form: $$\log \lambda_i(t) = \log \lambda_0(t) + \log \theta_i$$ where $\log \theta_i$ is the log relative hazard (argument log_hz).
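So yes: the tensor returned by cox_model(x) is the log relative hazard $\log \theta_i$, which you can read as a per-subject (log) risk score, and it is exactly what is passed as log_hz to the loss. A minimal sketch (x, event and time stand for your own data):

from torchsurv.loss.cox import neg_partial_log_likelihood

log_hz = cox_model(x)                                   # log relative hazards, shape (n, 1)
loss = neg_partial_log_likelihood(log_hz, event, time)  # Cox partial likelihood loss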

The baseline hazard $\lambda_0(t)$ is not required to evaluate the partial log likelihood. As you mentioned, there are methods to estimate it, but they are not provided by our package. Nonetheless, should you be able to obtain $$\int_0^t \lambda_0(s)\, ds$$, then you can evaluate the survival function of subject $i$ at time $t$ with

$$S_i(t) = \exp\left(- \int_0^t \lambda_i(s)\, ds\right) = \exp\left(- \theta_i \times \int_0^t \lambda_0(s)\, ds\right) = \exp\left(- \exp(\log \theta_i) \times \int_0^t \lambda_0(s)\, ds\right)$$
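If you want to try this, below is a minimal sketch of a Breslow-type estimator of the cumulative baseline hazard written in plain PyTorch. It is not part of torchsurv: cox_model, x, time and event stand for your own trained model and data, and this simple version does nothing special about tied event times.

import torch

def breslow_cumulative_baseline_hazard(log_hz, time, event):
    """Breslow estimator: H0(t) = sum over event times t_i <= t of
    1 / sum_{j: time_j >= t_i} exp(log_hz_j). Returns (sorted times, H0)."""
    order = torch.argsort(time)
    time, event, log_hz = time[order], event[order], log_hz[order].squeeze(-1)

    # risk-set denominators: reversed cumulative sum of exp(log_hz) over the sorted times
    denom = torch.flip(torch.cumsum(torch.flip(torch.exp(log_hz), dims=[0]), dim=0), dims=[0])

    # hazard increment is 1/denominator at event times, 0 at censored times
    increments = event.float() / denom

    return time, torch.cumsum(increments, dim=0)

# survival function of subject i on the observed time grid: S_i(t) = exp(-exp(log_hz_i) * H0(t))
cox_model.eval()
with torch.no_grad():
    log_hz = cox_model(x)  # log relative hazards, shape (n, 1)

t_sorted, H0 = breslow_cumulative_baseline_hazard(log_hz, time, event)
surv = torch.exp(-torch.exp(log_hz.squeeze(-1)).unsqueeze(1) * H0.unsqueeze(0))  # shape (n, n_times)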