Novartis / torchsurv

Deep survival analysis made easy
https://opensource.nibr.com/torchsurv/
MIT License

Question: Can Torchsurv Handle Time Series Data for Survival Analysis? #42

Open · Lamgayin opened this issue 4 days ago

Lamgayin commented 4 days ago

Dear authors, hello! I am new to the field of survival analysis. My research topic concerns applying time series models to survival analysis, and I am not sure whether TorchSurv supports this. Thanks a lot!

melodiemonod commented 1 day ago

Hey @Lamgayin, thank you for your question. For now, TorchSurv does not support time-dependent covariates. We will add this to our list of potential new features. Best regards, Melodie

tcoroller commented 5 hours ago

Hi @Lamgayin,

TorchSurv is designed to work with any model architecture and any data. If you want to use a time series model to predict survival, you can simply use an RNN (many-to-one) or a transformer architecture to fit your longitudinal data. As long as your model outputs a single estimate per sample (or two, for the Weibull model), you can connect any TorchSurv function (losses and/or metrics) to it.
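The reply above names a transformer as an alternative to an RNN. As a minimal sketch (not part of TorchSurv), the module below maps a batch of sequences to `n_outputs` estimates per sample: 1 for a Cox log relative hazard, or 2 for the Weibull parameters. The class name `TransformerSurvival`, the input projection, the learned positional embedding, pooling on the last time step, and all sizes are illustrative assumptions. Unlike the RNN example, it uses the batch-first layout `(batch_size, seq_length, input_size)`.

```python
import torch


class TransformerSurvival(torch.nn.Module):
    """Hypothetical many-to-one transformer: (batch, seq, features) -> (batch, n_outputs).

    n_outputs=1 for a Cox log relative hazard, n_outputs=2 for Weibull parameters.
    """

    def __init__(self, input_size, d_model=16, nhead=2, num_layers=2,
                 max_len=100, n_outputs=1):
        super().__init__()
        # Project the raw features to the model width (d_model must be divisible by nhead)
        self.proj = torch.nn.Linear(input_size, d_model)
        # Learned positional embedding so the encoder can use the order of time steps
        self.pos = torch.nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)
        layer = torch.nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=32, batch_first=True
        )
        self.encoder = torch.nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = torch.nn.Linear(d_model, n_outputs)

    def forward(self, x):
        seq_len = x.size(1)  # must not exceed max_len
        z = self.proj(x) + self.pos[:, :seq_len]
        z = self.encoder(z)
        # Many to one: summarize each sequence by its encoded last time step
        return self.head(z[:, -1])


torch.manual_seed(0)
x = torch.randn(8, 5, 10)  # 8 samples, 5 time steps, 10 features
cox_model = TransformerSurvival(input_size=10, n_outputs=1)
weibull_model = TransformerSurvival(input_size=10, n_outputs=2)
print(cox_model(x).shape)      # torch.Size([8, 1])
print(weibull_model(x).shape)  # torch.Size([8, 2])
```

The `(n, 1)` output can be passed to the Cox loss just like `estimates` in the RNN example. The `(n, 2)` output is intended for TorchSurv's Weibull loss (`weibull.neg_log_likelihood(log_params, event, time)` from `torchsurv.loss`, with the log scale in the first column and the log shape in the second); it is worth confirming that convention against the TorchSurv documentation for your installed version.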

I attached a simple code example to illustrate how to use time series data with an RNN in TorchSurv. Here I am using 10 features across 5 time steps. The RNN outputs a single estimate for each sample, and from there it is easy to connect TorchSurv loss functions and metrics.

Good luck!


```python
import torch
from torchsurv.loss import cox
from torchsurv.metrics.cindex import ConcordanceIndex

# Parameters
input_size = 10  # number of features per time step
output_size = 1  # RNN hidden size; one unit gives one estimate per sample
num_layers = 2   # number of stacked RNN layers
seq_length = 5   # number of time steps
batch_size = 8   # number of samples

# Make random boolean event indicators (True = event observed, False = censored)
events = torch.rand(batch_size) > 0.5
print(events)  # e.g. tensor([ True, False,  True,  True, False, False,  True, False])

# Make random positive times to event or censoring
time = torch.rand(batch_size) * 100
print(time)  # e.g. tensor([32.8563, 38.3207, 24.6015, 72.2986, 19.9004, 65.2180, 73.2083, 21.2663])

# Create a simple RNN model; its default input layout is (seq_length, batch_size, input_size)
rnn = torch.nn.RNN(input_size, output_size, num_layers)
inputs = torch.randn(seq_length, batch_size, input_size)
h0 = torch.randn(num_layers, batch_size, output_size)  # initial hidden state (zeros if omitted)

# Forward pass of the time series input
outputs, _ = rnn(inputs, h0)  # outputs: (seq_length, batch_size, output_size)
estimates = outputs[-1]  # keep only the last time step: many-to-one approach
print(f"Estimate shape for {batch_size} samples = {estimates.size()}")  # Estimate shape for 8 samples = torch.Size([8, 1])

# Cox negative partial log-likelihood; it carries a gradient, so it can be backpropagated
loss = cox.neg_partial_log_likelihood(estimates, events, time)
print(f"loss = {loss}, has gradient = {loss.requires_grad}")  # e.g. loss = 1.0389232635498047, has gradient = True

# Concordance index of the (untrained, random) estimates
cindex = ConcordanceIndex()
print(f"c-index = {cindex(estimates, events, time)}")  # e.g. c-index = 0.20000000298023224
```