autonlab / auton-survival

Auton Survival - an open source package for Regression, Counterfactual Estimation, Evaluation and Phenotyping with Censored Time-to-Events
http://autonlab.github.io/auton-survival
MIT License

RuntimeError: expected scalar type Float but found Double #93

Closed zapaishchykova closed 1 year ago

zapaishchykova commented 1 year ago

Hi! Thanks for the great package! Any ideas on why the same dataset works for all models, except the dsm one?

Here is the error log:

At hyper-param {'distribution': 'Weibull', 'k': 2, 'layers': [100, 100], 'learning_rate': 1e-05}
At fold: 0
100%|███████████████████████████████████| 10000/10000 [00:05<00:00, 1678.94it/s]
100%|██████████████████████████████████████████| 50/50 [00:00<00:00, 222.40it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [126], line 34
     24 # Instantiate an auton_survival Experiment 
     25 #dsm  cph
     26 #Survival model choices include:
   (...)
     30 # |      - 'rsf' : Random Survival Forests [1] model
     31 # |      - 'cph' : Cox Proportional Hazards [2] model
     32 experiment = SurvivalRegressionCV(model='dsm', num_folds=6, 
     33                                     hyperparam_grid=param_grid)
---> 34 model = experiment.fit(x, outcomes, metric='ibs',horizons=times)
     36 times = np.quantile(outcomes.time[outcomes.event==1], [0.25, 0.5, 0.6]).tolist()
     38 # Fit the `experiment` object with the specified Cox model.
     39 #experiment = estimators.SurvivalModel(model='dsm')
     40 #model = experiment.fit(x, outcomes)

File /mnt/survival_notebooks/../auton-survival-master/auton_survival/experiments.py:164, in SurvivalRegressionCV.fit(self, features, outcomes, horizons, metric)
    162 model = SurvivalModel(self.model, random_seed=self.random_seed, **hyper_param)
    163 model.fit(features.loc[self.folds!=fold], outcomes.loc[self.folds!=fold])
--> 164 predictions = model.predict_survival(features.loc[self.folds==fold], times=horizons)
    166 score = survival_regression_metric(metric=self.metric, 
    167                                    outcomes=outcomes.loc[self.folds==fold],
    168                                    predictions=predictions,
    169                                    times=horizons,
    170                                    outcomes_train=outcomes.loc[self.folds!=fold])
    171 fold_scores.append(np.mean(score))

File /mnt/survival_notebooks/../auton-survival-master/auton_survival/estimators.py:701, in SurvivalModel.predict_survival(self, features, times)
    699   return _predict_rsf(self._model, features, times)
    700 elif self.model == 'dsm':
--> 701   return _predict_dsm(self._model, features, times)
    702 elif self.model == 'dcph':
    703   return _predict_dcph(self._model, features, times)

File /mnt/survival_notebooks/../auton-survival-master/auton_survival/estimators.py:420, in _predict_dsm(model, features, times)
    400 def _predict_dsm(model, features, times):
    402   """Predict survival at specified time(s) using the Deep Survival Machines.
    403 
    404   Parameters
   (...)
    417 
    418   """
--> 420   survival_predictions = model.predict_survival(x=features.values, t=times)
    421   survival_predictions = pd.DataFrame(survival_predictions, columns=times).T
    423   return __interpolate_missing_times(survival_predictions, times)

File /mnt/survival_notebooks/../auton-survival-master/auton_survival/models/dsm/__init__.py:415, in DSMBase.predict_survival(self, x, t, risk)
    413   t = [t]
    414 if self.fitted:
--> 415   scores = losses.predict_cdf(self.torch_model, x, t, risk=str(risk))
    416   return np.exp(np.array(scores)).T
    417 else:

File /mnt/survival_notebooks/../auton-survival-master/auton_survival/models/dsm/losses.py:518, in predict_cdf(model, x, t_horizon, risk)
    516 torch.no_grad()
    517 if model.dist == 'Weibull':
--> 518   return _weibull_cdf(model, x, t_horizon, risk)
    519 if model.dist == 'LogNormal':
    520   return _lognormal_cdf(model, x, t_horizon, risk)

File /mnt/survival_notebooks/../auton-survival-master/auton_survival/models/dsm/losses.py:335, in _weibull_cdf(model, x, t_horizon, risk)
    331 def _weibull_cdf(model, x, t_horizon, risk='1'):
    333   squish = nn.LogSoftmax(dim=1)
--> 335   shape, scale, logits = model.forward(x, risk)
    336   logits = squish(logits)
    338   k_ = shape

File /mnt/survival_notebooks/../auton-survival-master/auton_survival/models/dsm/dsm_torch.py:204, in DeepSurvivalMachinesTorch.forward(self, x, risk)
    196 def forward(self, x, risk='1'):
    197   """The forward function that is called when data is passed through DSM.
    198 
    199   Args:
   (...)
    202 
    203   """
--> 204   xrep = self.embedding(x)
    205   dim = x.shape[0]
    206   return(self.act(self.shapeg[risk](xrep))+self.shape[risk].expand(dim, -1),
    207          self.act(self.scaleg[risk](xrep))+self.scale[risk].expand(dim, -1),
    208          self.gate[risk](xrep)/self.temp)

File ~/miniconda3/envs/pycox310/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/pycox310/lib/python3.10/site-packages/torch/nn/modules/container.py:139, in Sequential.forward(self, input)
    137 def forward(self, input):
    138     for module in self:
--> 139         input = module(input)
    140     return input

File ~/miniconda3/envs/pycox310/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/pycox310/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: expected scalar type Float but found Double

And the code:

# auton_survival cross-validation experiment.
import numpy as np  # needed for np.quantile below

from auton_survival.datasets import load_dataset
from auton_survival.preprocessing import Preprocessor
from auton_survival.metrics import survival_regression_metric
from auton_survival import estimators

param_grid = {'k' : [2],
              'distribution' : ['Weibull'],
              'learning_rate' : [1e-5],
              'layers' : [[100,100]]}

#outcomes, features = load_dataset(dataset='SUPPORT')
cat_feats = []
num_feats = list(features.columns)

preprocessor = Preprocessor(cat_feat_strat='ignore', num_feat_strat= 'mean') 
x = preprocessor.fit_transform(features, cat_feats=cat_feats, num_feats=num_feats,
                                one_hot=True, fill_value=-1)

x_val = preprocessor.fit_transform(features_val, cat_feats=cat_feats, num_feats=num_feats,
                                one_hot=True, fill_value=-1)

from auton_survival.experiments import SurvivalRegressionCV
# Instantiate an auton_survival Experiment 
#dsm  cph
#Survival model choices include:
# |      - 'dsm' : Deep Survival Machines [3] model
# |      - 'dcph' : Deep Cox Proportional Hazards [2] model
# |      - 'dcm' : Deep Cox Mixtures [4] model
# |      - 'rsf' : Random Survival Forests [1] model
# |      - 'cph' : Cox Proportional Hazards [2] model
# Evaluation horizons must be defined before they are passed to `fit`.
times = np.quantile(outcomes.time[outcomes.event==1], [0.25, 0.5, 0.6]).tolist()

experiment = SurvivalRegressionCV(model='dsm', num_folds=6, 
                                    hyperparam_grid=param_grid)
model = experiment.fit(x, outcomes, metric='ibs', horizons=times)

# Fit the `experiment` object with the specified Cox model.
#experiment = estimators.SurvivalModel(model='dsm')
#model = experiment.fit(x, outcomes)

times_val = np.quantile(outcomes_val.time[outcomes_val.event==1], [0.25, 0.5, 0.6]).tolist()
out_risk = model.predict_risk(x_val, times)
out_survival = model.predict_survival(x_val, times)  

print("Times:",times_val)
print("Brier scores")
print(survival_regression_metric('brs', outcomes_val, 
                                     out_survival, 
                                     times=times_val))

print("Time Dependent Concordance Index")
print(survival_regression_metric('ctd', outcomes_val, 
                                     out_survival, 
                                     times=times_val))

chiragnagpal commented 1 year ago

Hi @zapaishchykova, thanks for the kind words! It is a bit hard to tell exactly what the problem might be, but perhaps it has something to do with the floating point precision of your dataset. Can you check what the datatypes of x and outcomes are?
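For anyone landing here with the same error, a minimal sketch of that check using plain pandas is below. The two small frames are hypothetical stand-ins for `x` and `outcomes` (the `age` and `bmi` columns are made up for illustration); only the `time` and `event` column names come from the code above.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for the reporter's `x` and `outcomes` frames.
x = pd.DataFrame({"age": [61.0, 47.5, 70.2], "bmi": [22.1, 30.4, 27.8]}).astype("float32")
outcomes = pd.DataFrame({"time": [12, 30, 7], "event": [1, 0, 1]})

# Per-column dtypes of both frames.
print(x.dtypes)
print(outcomes.dtypes)

# Names of feature columns that are not already float64.
non_float64 = [col for col, dtype in x.dtypes.items() if dtype != np.float64]
print(non_float64)
```

Any column listed in `non_float64` is a candidate for the Float/Double mismatch reported in the traceback.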

zapaishchykova commented 1 year ago

Hello, x is a DataFrame with float32 columns, and outcomes is also a DataFrame, with times as int and durations as ints.

chiragnagpal commented 1 year ago

Can you try casting the features to float64 and the outcomes to int64?
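A rough sketch of that cast with pandas `astype` follows. The helper name `cast_for_dsm` and the sample frames are hypothetical and not part of auton-survival; the sketch assumes the outcomes frame has `time` and `event` columns, as in the code above.

```python
import pandas as pd

def cast_for_dsm(features, outcomes):
    """Hypothetical helper: return copies cast to float64 (features) and int64 (outcomes)."""
    features = features.astype("float64")
    outcomes = outcomes.astype({"time": "int64", "event": "int64"})
    return features, outcomes

# Small made-up frames with the narrower dtypes described in the thread.
x = pd.DataFrame({"age": [61.0, 47.5, 70.2]}, dtype="float32")
outcomes = pd.DataFrame({"time": [12, 30, 7], "event": [1, 0, 1]}, dtype="int32")

x64, outcomes64 = cast_for_dsm(x, outcomes)
print(x64.dtypes)
print(outcomes64.dtypes)
```

The cast frames would then be passed to `experiment.fit` in place of the originals; `astype` returns new frames, so the inputs are left unchanged.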

zapaishchykova commented 1 year ago

That worked, thanks!

chiragnagpal commented 1 year ago

Thanks for flagging this. The API should be able to perform the appropriate typecasting itself; I'll try to fix this upstream soon. Closing this for now.