abebe9849 closed this issue 1 month ago.
Dear @abebe9849,
Thank you for your interest and inquiry.
At the moment, we do not have a built-in function to predict survival times. However, in the Weibull AFT model, survival times follow a Weibull distribution whose scale and shape are the exponentiated log parameters output by the model. You can therefore predict survival times by drawing random samples from this distribution and, for example, averaging them per subject.
Here is how you could do it with synthetic data:
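For reference, these are the relations involved. The formulas assume the standard Weibull parameterisation used by `torch.distributions.Weibull`, with scale $\lambda = e^{\text{log scale}}$ and shape $k = e^{\text{log shape}}$. Inverse-transform sampling (second line) is one way to draw from the distribution. The last line gives closed-form point predictions.

```latex
S(t) = \Pr(T > t) = \exp\!\left[-\left(\frac{t}{\lambda}\right)^{k}\right]

T = \lambda\,(-\ln U)^{1/k}, \qquad U \sim \mathrm{Uniform}(0, 1)
\quad\Longrightarrow\quad \Pr(T > t) = \Pr\!\left(U < e^{-(t/\lambda)^{k}}\right) = S(t)

\mathbb{E}[T] = \lambda\,\Gamma\!\left(1 + \tfrac{1}{k}\right), \qquad
\operatorname{median}(T) = \lambda\,(\ln 2)^{1/k}
```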
```python
import torch
from torchsurv.loss.weibull import survival_function

_ = torch.manual_seed(42)

# Generate synthetic data
n = 50
time = torch.randint(low=1, high=100, size=(n,))
log_params = torch.randn((n, 2))  # column 0: log scale, column 1: log shape

# Compute the survival probability from Weibull log parameters
surv_prob = survival_function(log_params, time)  # shape = (n, n)

# Generate predicted survival times
m = 10000  # number of samples drawn per subject
log_scale, log_shape = log_params.unbind(1)
weibull = torch.distributions.weibull.Weibull(
    scale=torch.exp(log_scale), concentration=torch.exp(log_shape)
)
predicted_time = weibull.sample((m,))  # shape = (m, n)
predicted_time_mean = predicted_time.mean(dim=0)  # shape = (n,)
```
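If a single point prediction per subject is enough, sampling is not strictly necessary. The Weibull mean and median have closed forms, and `predicted_time_mean` above is a Monte Carlo estimate of the mean. The sketch below uses only the standard library so it can be checked without torch. `weibull_point_predictions` and `monte_carlo_mean` are hypothetical helper names, and the example log parameters are arbitrary. It computes both point predictions and checks the mean against sampling. In torch, `torch.distributions.weibull.Weibull(...).mean` returns the same closed-form mean.

```python
import math
import random


def weibull_point_predictions(log_scale: float, log_shape: float) -> tuple[float, float]:
    """Hypothetical helper: closed-form mean and median survival time of a
    Weibull distribution given its log scale and log shape."""
    scale = math.exp(log_scale)  # lambda
    shape = math.exp(log_shape)  # k
    mean = scale * math.gamma(1.0 + 1.0 / shape)
    median = scale * math.log(2.0) ** (1.0 / shape)
    return mean, median


def monte_carlo_mean(log_scale: float, log_shape: float, m: int = 100_000, seed: int = 42) -> float:
    """Hypothetical helper: sampling-based estimate of the mean,
    mirroring the sample-then-average approach in the torch code."""
    rng = random.Random(seed)
    scale, shape = math.exp(log_scale), math.exp(log_shape)
    # random.weibullvariate(alpha, beta): alpha is the scale, beta is the shape
    return sum(rng.weibullvariate(scale, shape) for _ in range(m)) / m


mean, median = weibull_point_predictions(log_scale=1.0, log_shape=0.5)
estimate = monte_carlo_mean(log_scale=1.0, log_shape=0.5)
print(f"analytic mean={mean:.3f}, median={median:.3f}, sampled mean={estimate:.3f}")
```

The closed forms avoid sampling noise. Sampling remains useful when you want the full predictive distribution, for example quantiles or prediction intervals.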
I hope this is useful.
Melodie
Is it possible to predict survival times directly, as in the following torchlife code?