Novartis / torchsurv

Deep survival analysis made easy
https://opensource.nibr.com/torchsurv/
MIT License

How to predict survival time using an AFT model? #31

Closed abebe9849 closed 1 month ago

abebe9849 commented 2 months ago

Is it possible to predict survival times directly, as in the following torchlife code?


from torchlife.model import ModelAFT

model = ModelAFT('Gumbel')
model.fit(df)
surv_prob = model.predict(df)
mode_time = model.predict_time(df)

melodiemonod commented 2 months ago

Dear @abebe9849.

Thank you for your interest and inquiry.

At the moment we do not have a built-in function to predict survival time. However, under the Weibull AFT model, each subject's survival time follows a Weibull distribution whose scale and shape come from the model output. You can therefore estimate survival time by drawing random samples from that distribution and summarizing them, for example by taking their mean.

Using synthetic data, this is how you could do it:

import torch
from torchsurv.loss.weibull import survival_function

_ = torch.manual_seed(42)

# Generate synthetic data
n = 50
time = torch.randint(low=1, high=100, size=(n,))
log_params = torch.randn((n, 2))

# Compute the survival probability from the Weibull log parameters
surv_prob = survival_function(log_params, time)  # shape = (n, n)

# Draw m survival times per subject from each subject's Weibull distribution
m = 10000  # number of Monte Carlo samples per subject
log_scale, log_shape = log_params.unbind(1)
predicted_time = torch.distributions.weibull.Weibull(
    torch.exp(log_scale), torch.exp(log_shape)
).sample((m,))  # shape = (m, n)

# Average the samples to get one predicted survival time per subject
predicted_time_mean = predicted_time.mean(dim=0)  # shape = (n,)
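
If a single point prediction is all you need, sampling may not be necessary: for a Weibull distribution with scale λ and shape k, the mean is λ·Γ(1 + 1/k), the median is λ·(ln 2)^(1/k), and the mode is λ·((k − 1)/k)^(1/k) when k > 1 (and 0 otherwise). The sketch below assumes the same parameter layout as the snippet above (first column log scale, second column log shape). The names mean_time, median_time, and mode_time are illustrative only and are not part of the torchsurv API.

```python
import math
import torch

torch.manual_seed(42)

# Synthetic log parameters: column 0 = log scale, column 1 = log shape
# (assumed to match the layout used in the sampling snippet above).
n = 5
log_params = torch.randn((n, 2))
log_scale, log_shape = log_params.unbind(1)
scale, shape = torch.exp(log_scale), torch.exp(log_shape)

# Mean: scale * Gamma(1 + 1/shape), computed via lgamma for numerical stability
mean_time = scale * torch.exp(torch.lgamma(1 + 1 / shape))

# Median: the time at which the survival probability equals 0.5
median_time = scale * torch.tensor(math.log(2.0)) ** (1 / shape)

# Mode: scale * ((shape - 1) / shape) ** (1 / shape) if shape > 1, else 0
mode_time = torch.where(
    shape > 1,
    scale * ((shape - 1) / shape).clamp(min=0) ** (1 / shape),
    torch.zeros_like(scale),
)

print(mean_time, median_time, mode_time)  # each has shape (n,)
```

The closed-form mean should agree with predicted_time_mean from the sampling approach up to Monte Carlo error. When the shape parameter is small the distribution has a heavy right tail and the mean can become very large, so the median may be the more stable summary in that case.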

I hope this is useful.

Melodie