Lamgayin opened 4 days ago
Hey @Lamgayin, thank you for your question. For now, TorchSurv does not support time-dependent covariates. We will consider adding this as a new feature. Best regards, Melodie
Hi @Lamgayin,
By design, TorchSurv lets you use any model architecture and any data. If you want a time-series model to predict survival, you can use an RNN (many-to-one) or a transformer architecture to fit your longitudinal data. As long as your model outputs a single estimate per sample (or two, for Weibull), you can connect any TorchSurv function (losses and/or metrics) to it.
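To make "a single (or two for Weibull) estimate" more concrete: for a Cox model the single output is read as a log relative hazard, while a Weibull model needs two outputs per sample. The plain-Python sketch below assumes those two outputs are the log scale λ and the log shape k (please check the TorchSurv documentation for the exact convention) and turns them into a survival probability S(t) = exp(-(t/λ)^k). The helper `weibull_survival` is hypothetical and not part of TorchSurv.

```python
import math


def weibull_survival(log_scale: float, log_shape: float, t: float) -> float:
    """Survival probability S(t) = exp(-(t / scale) ** shape).

    Hypothetical helper for illustration: assumes the model emits
    log(scale) and log(shape), so exponentiating keeps both positive.
    """
    scale = math.exp(log_scale)
    shape = math.exp(log_shape)
    return math.exp(-((t / scale) ** shape))


# Two raw outputs for one sample, e.g. the last step of a many-to-one RNN
log_params = (math.log(50.0), math.log(1.5))  # scale = 50, shape = 1.5
for t in (0.0, 25.0, 50.0, 100.0):
    print(t, round(weibull_survival(*log_params, t), 4))
```

Predicting the logarithms rather than the parameters themselves is a common choice because an unconstrained network output can then never produce a negative scale or shape.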
I attached a simple code example to illustrate how to combine an RNN on time-series input with TorchSurv. Here I am using 10 features across 5 time steps for a batch of 8 samples. The RNN outputs a single estimate for each sample, and from there it is easy to connect TorchSurv's loss functions and metrics.
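For intuition about what the Cox loss does with those per-sample estimates, here is a plain-Python sketch. Assuming no tied event times, each observed event contributes its log hazard minus the log-sum-exp of the log hazards of all samples still at risk at that time; the sum is negated and averaged over events. This is an illustrative re-implementation with a hypothetical name, not TorchSurv's code, and TorchSurv's `cox.neg_partial_log_likelihood` may differ in details such as tie handling and reduction.

```python
import math


def cox_neg_partial_log_likelihood(log_hz, events, times):
    """Negative Cox partial log-likelihood, averaged over observed events.

    Hypothetical helper for illustration; assumes no tied times.
    log_hz: per-sample log relative hazards (the model's single output)
    events: True if the event was observed, False if censored
    times:  event or censoring time of each sample
    """
    terms = []
    for eta_i, event_i, t_i in zip(log_hz, events, times):
        if not event_i:
            continue  # censored samples only appear inside risk sets
        # Risk set: samples still under observation at time t_i
        risk = [eta_j for eta_j, t_j in zip(log_hz, times) if t_j >= t_i]
        m = max(risk)  # subtract the max for a numerically stable log-sum-exp
        log_sum = m + math.log(sum(math.exp(e - m) for e in risk))
        terms.append(eta_i - log_sum)
    if not terms:
        raise ValueError("at least one observed event is required")
    return -sum(terms) / len(terms)


print(cox_neg_partial_log_likelihood([0.5, -0.2, 1.0], [True, False, True], [3.0, 5.0, 8.0]))
```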
Good luck!
```python
import torch
from torchsurv.loss import cox
from torchsurv.metrics.cindex import ConcordanceIndex

# Note: no random seed is set, so printed values marked "e.g." come from one example run.

# Parameters
input_size = 10  # number of features per time step
output_size = 1  # RNN hidden size; 1 so each time step emits a single value
num_layers = 2
seq_length = 5  # number of time steps
batch_size = 8  # number of samples

# make random boolean events
events = torch.rand(batch_size) > 0.5
print(events)  # e.g. tensor([ True, False, True, True, False, False, True, False])

# make random positive time to event
time = torch.rand(batch_size) * 100
print(time)  # e.g. tensor([32.8563, 38.3207, 24.6015, 72.2986, 19.9004, 65.2180, 73.2083, 21.2663])

# Create simple RNN model; by default inputs are (seq_length, batch_size, input_size)
# and the default tanh nonlinearity keeps each output in [-1, 1]
rnn = torch.nn.RNN(input_size, output_size, num_layers)
inputs = torch.randn(seq_length, batch_size, input_size)
h0 = torch.randn(num_layers, batch_size, output_size)  # initial hidden state

# Forward pass time series input
outputs, _ = rnn(inputs, h0)  # outputs: (seq_length, batch_size, output_size)
estimates = outputs[-1]  # Keep only last predictions, many-to-one approach
print(estimates.size())  # torch.Size([8, 1])
print(f"Estimate shape for {batch_size} samples = {estimates.size()}")  # Estimate shape for 8 samples = torch.Size([8, 1])

# Cox loss: each estimate is treated as a log relative hazard
loss = cox.neg_partial_log_likelihood(estimates, events, time)
print(f"loss = {loss}, has gradient = {loss.requires_grad}")  # e.g. loss = 1.0389232635498047, has gradient = True

cindex = ConcordanceIndex()
print(f"c-index = {cindex(estimates, events, time)}")  # e.g. c-index = 0.20000000298023224
```
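The c-index above comes from an untrained, randomly initialised RNN on random data, so it carries no information; with only 8 samples, values well away from 0.5 (like 0.2 here) can easily occur by chance. To see what the metric counts, here is a plain-Python sketch of Harrell's concordance index: a pair (i, j) is comparable when sample i had an observed event before time j, and concordant when the model gives i the higher estimate, assuming a higher estimate means higher risk (as with a Cox log hazard). Tied estimates count as half. The helper name is hypothetical, and TorchSurv's `ConcordanceIndex` may treat tied times and weighting differently.

```python
def harrell_cindex(estimates, events, times):
    """Harrell's C-index. Hypothetical helper for illustration.

    Assumes a higher estimate means higher risk (shorter expected survival).
    Pairs with tied times are skipped for simplicity.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored sample cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:  # i had its event before j's time
                comparable += 1
                if estimates[i] > estimates[j]:
                    concordant += 1.0
                elif estimates[i] == estimates[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


print(harrell_cindex([2.0, 1.0, 0.0], [True, True, False], [1.0, 2.0, 3.0]))
```

A perfect risk ordering gives 1.0, a fully reversed one gives 0.0, and constant estimates give 0.5.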
Dear authors, hello! I am a newcomer to survival analysis. My research topic concerns applying time-series models to survival analysis, and I am not sure whether TorchSurv supports this. Thanks a lot!