Novartis / torchsurv

Deep survival analysis made easy
http://opensource.nibr.com/torchsurv/
MIT License

Censoring #36

Closed: StatMixedML closed this issue 5 months ago

StatMixedML commented 5 months ago

Thanks for your effort and for making the implementation available.

I just came across torchsurv and was wondering whether the Weibull implementation can account for censored data.

Thanks!

melodiemonod commented 5 months ago

Dear @StatMixedML,

Yes, it can take right-censored data. Please see the descriptions of the function's arguments in the documentation.
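
For readers following along, here is a minimal sketch of what right-censoring means for a Weibull likelihood: an observed event contributes the log-density, a censored one contributes the log-survival function. It uses only the Python standard library and is not torchsurv's implementation; the helper names (`weibull_log_pdf`, `weibull_log_survival`, `weibull_nll_right_censored`) are hypothetical and chosen for illustration.

```python
import math


def weibull_log_pdf(t, scale, shape):
    # log f(t) = log(shape/scale) + (shape - 1) * log(t/scale) - (t/scale)**shape
    z = t / scale
    return math.log(shape / scale) + (shape - 1.0) * math.log(z) - z ** shape


def weibull_log_survival(t, scale, shape):
    # log S(t) = -(t/scale)**shape
    return -((t / scale) ** shape)


def weibull_nll_right_censored(times, events, scale, shape):
    """Summed negative log-likelihood; events[i] is True if the event was observed."""
    total = 0.0
    for t, observed in zip(times, events):
        if observed:
            total += weibull_log_pdf(t, scale, shape)       # event seen at time t
        else:
            total += weibull_log_survival(t, scale, shape)  # still event-free at time t
    return -total


times = [1.0, 2.0, 0.5]
events = [True, False, True]  # the second observation is right-censored
nll = weibull_nll_right_censored(times, events, scale=1.0, shape=1.0)
print(f"NLL: {nll:.4f}")  # NLL: 3.5000
```

With `scale=1.0` and `shape=1.0` the Weibull reduces to a unit-rate exponential, so both contributions equal `-t` and the NLL is simply the sum of the times.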

StatMixedML commented 5 months ago

Dear @melodiemonod,

Many thanks for your reply. It works and gives the same NLL as the TensorFlow Probability implementation in the case of right-censoring.

Any plans to also include left and interval censoring?

melodiemonod commented 5 months ago

Super, thank you for letting us know! We will keep a log of it to demonstrate reproducibility. If you want to share your code below, you are most welcome to.

> Any plans to also include left and interval censoring?

You are working with the first release of the package. We are waiting for requests from users and will address the most popular ones first in future releases. So possibly in the future, if there is interest :)
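
As background on the request, a rough sketch of how left- and interval-censored observations would enter a Weibull log-likelihood: a left-censored time contributes log F(t), and an interval-censored one contributes log(F(upper) - F(lower)). This is plain standard-library Python, not a torchsurv feature; the function names are hypothetical.

```python
import math


def weibull_cdf(t, scale, shape):
    # F(t) = 1 - exp(-(t/scale)**shape); -expm1(-x) evaluates 1 - exp(-x) stably
    return -math.expm1(-((t / scale) ** shape))


def log_lik_left_censored(t, scale, shape):
    # The event happened at some unknown time before t
    return math.log(weibull_cdf(t, scale, shape))


def log_lik_interval_censored(lower, upper, scale, shape):
    # The event happened somewhere in (lower, upper]
    if not 0.0 <= lower < upper:
        raise ValueError("need 0 <= lower < upper")
    return math.log(weibull_cdf(upper, scale, shape) - weibull_cdf(lower, scale, shape))


left = log_lik_left_censored(2.0, scale=1.0, shape=1.0)
interval = log_lik_interval_censored(1.0, 2.0, scale=1.0, shape=1.0)
print(f"left: {left:.4f}, interval: {interval:.4f}")
```

Setting `lower=0.0` in the interval case recovers the left-censored contribution, which is a quick sanity check on both helpers.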

StatMixedML commented 5 months ago

@melodiemonod Ok thanks!

Please find below the MWE:

###
# Imports
###
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import torch
from torch.distributions import Weibull as Weibull_PyTorch
from torchsurv.loss.weibull import neg_log_likelihood

tfd = tfp.distributions

###
# Data (taken from https://probflow.readthedocs.io/en/stable/examples/time_to_event.html#)
###
np.random.seed(123)
tf.random.set_seed(123)

# Generate time-to-event data
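n_samples = 1000  # number of draws from each distribution for the sample-mean check
n_obs = 800       # number of observations
n_feat = 2        # number of features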
x = np.random.rand(n_obs, n_feat)
w = np.array([[1.], [-1.]]).astype('float32')
b = np.random.rand(1, 1)
y = tfd.Exponential(tf.nn.softplus(x@w + b)).sample().numpy()

# Generate Weibull Parameters
concentration = np.exp(np.random.rand(n_obs,1))
scale = np.exp(np.random.rand(n_obs,1))

# Simulate purchase times
t_stop = 10 #time of data collection
t1 = t_stop*np.random.rand(n_obs)
t2 = t1 + y[:, 0]
cix = t2>t_stop

# Mark right-censored observations: when no event occurred before t_stop,
# store the negative of the time waited so far (t_stop - t1)
y[cix, 0] = t1[cix]-t_stop

###
# TensorFlow
###
tf.random.set_seed(123)

dist = tfd.Weibull(concentration = concentration, scale = scale)

# Log-likelihoods of observed time-to-events
obs_ll = dist.log_prob(y)[y>=0]

# Log-likelihoods of right-censored data (log-survival at the censoring time)
right_ll = dist.log_survival_function(-y)[y<0]

# Total log-likelihood, negated to obtain the NLL
nll_tfp = -(tf.reduce_sum(obs_ll) + tf.reduce_sum(right_ll))

# Samples
samples_tfp = dist.sample(n_samples)

print(f"NLL: {nll_tfp.numpy():.4f} \nSample-Mean: {tf.math.reduce_mean(samples_tfp).numpy():.4f}")

###
# Torchsurv
###
torch.manual_seed(123)

# Log-params: column 0 is log(scale), column 1 is log(concentration), i.e. the Weibull shape
log_params = torch.tensor(np.concatenate([np.log(scale), np.log(concentration)], axis=1), dtype = torch.float32)

# Time and event: event is True for observed times, time is the magnitude of the signed wait
event_var = torch.tensor(y >= 0).reshape(-1,)
time_var = torch.tensor(y).abs().reshape(-1,)

# NLL
nll_pytorch = neg_log_likelihood(log_params, event=event_var, time=time_var, reduction="sum")

# Initialize Distribution for sampling
weibull_pytorch = Weibull_PyTorch(scale=torch.tensor(scale), concentration=torch.tensor(concentration))

print(f"NLL: {nll_pytorch.numpy():.4f} \nSample-Mean: {weibull_pytorch.sample((n_samples,)).numpy().mean():.4f}")
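
The signed encoding used above (non-negative values are observed times, negative values are right-censored waits) is what the `event_var` / `time_var` lines unpack. The same conversion in plain Python, as a small illustration; `split_signed_times` is a hypothetical helper name, not part of torchsurv:

```python
def split_signed_times(signed_times):
    """Turn signed wait times into (events, times) lists.

    Convention from the MWE: y >= 0 is an observed time-to-event,
    y < 0 is a right-censored wait whose magnitude is the time observed so far.
    """
    events = [y >= 0 for y in signed_times]
    times = [abs(y) for y in signed_times]
    return events, times


events, times = split_signed_times([0.7, -2.5, 1.2, -0.1])
print(events)  # [True, False, True, False]
print(times)   # [0.7, 2.5, 1.2, 0.1]
```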

tcoroller commented 4 months ago

Thanks @StatMixedML for the suggestion. I created a backlog issue for a future release.