autonlab / auton-survival

Auton Survival - an open source package for Regression, Counterfactual Estimation, Evaluation and Phenotyping with Censored Time-to-Events
http://autonlab.github.io/auton-survival
MIT License

mat1 and mat2 shapes cannot be multiplied (10424x58 and 57x100) #105

Closed: tommyvcc closed this issue 1 year ago.

tommyvcc commented 1 year ago

Dear maintainers,

When I run the Deep Cox Mixtures and Deep Cox Proportional Hazards models on my data, I get the following error:

    RuntimeError                              Traceback (most recent call last)
    ~\AppData\Local\Temp\ipykernel_14400\3101778186.py in <module>
         23
         24     # Obtain survival probabilities for validation set and compute the Integrated Brier Score
    ---> 25     predictions_val = model.predict_survival(x_val, times)
         26     metric_val = survival_regression_metric('ibs', y_val, predictions_val, times, y_tr)
         27     models.append([metric_val, model])

    ~\Desktop\auton_survival 2\auton-survival-master\auton_survival\estimators.py in predict_survival(self, features, times)
        701       return _predict_dsm(self._model, features, times)
        702     elif self.model == 'dcph':
    --> 703       return _predict_dcph(self._model, features, times)
        704     elif self.model == 'dcm':
        705       return _predict_dcm(self._model, features, times)

    ~\Desktop\auton_survival 2\auton-survival-master\auton_survival\estimators.py in _predict_dcph(model, features, times)
        232     times = times.ravel().tolist()
        233
    --> 234   return model.predict_survival(x=features.values, t=times)
        235
        236 def _fit_cph(features, outcomes, val_data, random_seed, **hyperparams):

    ~\Desktop\auton_survival 2\auton-survival-master\auton_survival\models\cph\__init__.py in predict_survival(self, x, t)
        231       t = [t]
        232
    --> 233     scores = predict_survival(self.torch_model, x, t)
        234     return scores
        235

    ~\Desktop\auton_survival 2\auton-survival-master\auton_survival\models\cph\dcph_utilities.py in predict_survival(model, x, t)
        144
        145   model, breslow_spline = model
    --> 146   lrisks = model(x).detach().cpu().numpy()
        147
        148   unique_times = breslow_spline.baseline_survival_.x

    ~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
       1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1193                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1194             return forward_call(*input, **kwargs)
       1195         # Do not call functions when jit is used
       1196         full_backward_hooks, non_full_backward_hooks = [], []

    ~\Desktop\auton_survival 2\auton-survival-master\auton_survival\models\cph\dcph_torch.py in forward(self, x)
         27   def forward(self, x):
         28
    ---> 29     return self.expert(self.embedding(x))
         30
         31 class DeepRecurrentCoxPHTorch(DeepCoxPHTorch):

    ~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
       1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1193                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1194             return forward_call(*input, **kwargs)
       1195         # Do not call functions when jit is used
       1196         full_backward_hooks, non_full_backward_hooks = [], []

    ~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\container.py in forward(self, input)
        202     def forward(self, input):
        203         for module in self:
    --> 204             input = module(input)
        205         return input
        206

    ~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
       1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1193                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1194             return forward_call(*input, **kwargs)
       1195         # Do not call functions when jit is used
       1196         full_backward_hooks, non_full_backward_hooks = [], []

    ~\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\linear.py in forward(self, input)
        112
        113     def forward(self, input: Tensor) -> Tensor:
    --> 114         return F.linear(input, self.weight, self.bias)
        115
        116     def extra_repr(self) -> str:

    RuntimeError: mat1 and mat2 shapes cannot be multiplied (10424x58 and 57x100)

I wonder if you could kindly help me with this.
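The last line of the traceback is the key clue. `F.linear` computes `input @ weight.T + bias`, and PyTorch reports `mat2` as the transposed weight, so `57x100` means a layer with 57 input features and 100 outputs, while the batch passed in has 58 columns. A small NumPy sketch of that shape rule; `linear` here is a hypothetical stand-in for the PyTorch call, not auton-survival code:

```python
import numpy as np

def linear(x, weight, bias):
    """Shape-checked y = x @ weight.T + bias, mirroring torch.nn.functional.linear.

    weight has shape (out_features, in_features), so x needs exactly
    in_features columns. The message reports x and weight.T, as PyTorch does.
    """
    if x.shape[1] != weight.shape[1]:
        raise ValueError(
            f"mat1 and mat2 shapes cannot be multiplied "
            f"({x.shape[0]}x{x.shape[1]} and {weight.shape[1]}x{weight.shape[0]})"
        )
    return x @ weight.T + bias

# A layer with 57 inputs and 100 outputs, matching the "57x100" in the error.
weight = np.zeros((100, 57))
bias = np.zeros(100)

print(linear(np.ones((4, 57)), weight, bias).shape)  # (4, 100)
try:
    linear(np.ones((10424, 58)), weight, bias)
except ValueError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (10424x58 and 57x100)
```

In other words, the network was built for 57 features, but the validation batch passed to `predict_survival` has 58.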

chiragnagpal commented 1 year ago

Is the model training correctly, or do you get this error during training?

tommyvcc commented 1 year ago

Thank you very much for getting back to me.

When I tried the Cox proportional hazards model, it worked very well and gave me all the results for CTD, IBS, Brier score and AUC.

However, when I tried the Deep Cox Proportional Hazards model, I got the error after running the following code:

    import numpy as np
    from auton_survival.estimators import SurvivalModel
    from auton_survival.metrics import survival_regression_metric
    from sklearn.model_selection import ParameterGrid

    # Define parameters for tuning the model
    param_grid = {'bs' : [100, 200],
                  'learning_rate' : [1e-4, 1e-3],
                  'layers' : [[100], [100, 100]]
                 }
    params = ParameterGrid(param_grid)

    # Define the times for tuning the model hyperparameters and for evaluating the model
    times = np.quantile(y_tr['time'][y_tr['event']==1], np.linspace(0.1, 0.99, 10)).tolist()

    # Perform hyperparameter tuning
    models = []
    for param in params:
        model = SurvivalModel('dcph', random_seed=0, bs=param['bs'], learning_rate=param['learning_rate'], layers=param['layers'])

        # The fit method is called to train the model
        model.fit(x_tr, y_tr)

        # Obtain survival probabilities for validation set and compute the Integrated Brier Score
        predictions_val = model.predict_survival(x_val, times)
        metric_val = survival_regression_metric('ibs', y_val, predictions_val, times, y_tr)
        models.append([metric_val, model])

    # Select the best model based on the mean metric value computed for the validation set
    metric_vals = [i[0] for i in models]
    first_min_idx = metric_vals.index(min(metric_vals))
    model = models[first_min_idx][1]
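The grid above expands to 2 x 2 x 2 = 8 settings, so the loop fits eight DCPH models and calls `predict_survival(x_val, times)` after each fit. A stdlib sketch of that expansion, using `itertools.product` as a stand-in for scikit-learn's `ParameterGrid` (`expand_grid` is a hypothetical name):

```python
from itertools import product

def expand_grid(param_grid):
    """Return one dict per combination of the grid's values (keys sorted)."""
    keys = sorted(param_grid)
    return [dict(zip(keys, values))
            for values in product(*(param_grid[k] for k in keys))]

param_grid = {'bs': [100, 200],
              'learning_rate': [1e-4, 1e-3],
              'layers': [[100], [100, 100]]}

configs = expand_grid(param_grid)
print(len(configs))  # 8
print(configs[0])    # {'bs': 100, 'layers': [100], 'learning_rate': 0.0001}
```

Every one of these models is trained on `x_tr`, so each expects exactly as many input columns as `x_tr` has, whatever the hyperparameters.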

tommyvcc commented 1 year ago

I think the model trains correctly; the error happens after training, when I try to make predictions on x_val:

    ---> 25 predictions_val = model.predict_survival(x_val, times)
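Because the failure is at this call and the traceback reports 58 input columns against a layer built for 57, a quick check is whether `x_val` has exactly the same columns as `x_tr`. Order matters too: the traceback shows `_predict_dcph` passing `features.values`, so columns are matched by position, not by name. A minimal pandas sketch; `compare_columns` is a hypothetical helper, not part of auton-survival, and the toy frames are made up for illustration:

```python
import pandas as pd

def compare_columns(train, other):
    """Report columns present in one DataFrame but not the other."""
    train_cols, other_cols = set(train.columns), set(other.columns)
    return {
        "missing_in_other": sorted(train_cols - other_cols),
        "extra_in_other": sorted(other_cols - train_cols),
    }

# Toy frames: category 'C' only shows up in the validation split,
# so encoding each split separately yields one extra column.
x_tr = pd.get_dummies(pd.DataFrame({"grade": ["A", "B", "A"]}))
x_val = pd.get_dummies(pd.DataFrame({"grade": ["A", "B", "C"]}))

print(x_tr.shape[1], x_val.shape[1])  # 2 3
print(compare_columns(x_tr, x_val))   # {'missing_in_other': [], 'extra_in_other': ['grade_C']}
```

Running the same comparison on the real `x_tr` and `x_val`, plus `list(x_tr.columns) == list(x_val.columns)`, should show which column is responsible.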

tommyvcc commented 1 year ago

[screenshot 20230306_182158]

I am attaching a picture showing when the error starts to appear.

So, once training starts, it shows the following:

[screenshot 16781270860094487564288015332298]

tommyvcc commented 1 year ago

Sorted. It turned out that some of the variables had categories with very few observations.
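For readers hitting the same error: if each split is one-hot encoded separately, a rare category present in only one split produces a column only in that split, which would explain a 57- versus 58-column mismatch. The issue does not say which encoder was used, so this is a sketch assuming `pd.get_dummies`, showing two common remedies: collapse categories that are rare in the training split into one level, then reindex the other split to the training columns. `collapse_rare` and `encode_like_train` are hypothetical helpers, and the data is a toy example:

```python
import pandas as pd

def collapse_rare(train_col, other_col, min_count, other="Other"):
    """Map categories seen fewer than `min_count` times in training to `other`.

    Frequencies come from the training split only; any category not kept
    (including ones never seen in training) becomes `other` in both splits.
    """
    counts = train_col.value_counts()
    keep = counts[counts >= min_count].index
    return (train_col.where(train_col.isin(keep), other),
            other_col.where(other_col.isin(keep), other))

def encode_like_train(train_df, other_df, cat_cols, min_count=2):
    """One-hot encode both splits so `other_df` ends up with train's exact columns."""
    train_df, other_df = train_df.copy(), other_df.copy()
    for col in cat_cols:
        train_df[col], other_df[col] = collapse_rare(train_df[col], other_df[col], min_count)
    # dtype=float keeps .values numeric (bool dummies mixed with ints give object arrays).
    train_enc = pd.get_dummies(train_df, columns=cat_cols, dtype=float)
    other_enc = pd.get_dummies(other_df, columns=cat_cols, dtype=float)
    # Drop columns unseen in training, add missing ones as 0, and copy the
    # training column order, since the model reads features.values by position.
    other_enc = other_enc.reindex(columns=train_enc.columns, fill_value=0.0)
    return train_enc, other_enc

# Toy data: 'D' is rare in training and 'C' never appears there.
x_tr_raw = pd.DataFrame({"age": [50, 61, 47, 70, 58, 66, 45],
                         "grade": ["A", "A", "A", "B", "B", "B", "D"]})
x_val_raw = pd.DataFrame({"age": [52, 63, 49],
                          "grade": ["A", "B", "C"]})

x_tr, x_val = encode_like_train(x_tr_raw, x_val_raw, cat_cols=["grade"])
print(list(x_tr.columns))   # ['age', 'grade_A', 'grade_B', 'grade_Other']
print(list(x_val.columns))  # ['age', 'grade_A', 'grade_B', 'grade_Other']
```

Encoding the full feature table once, before splitting into train/validation/test, is another way to keep the column sets identical. The reindex step is still a useful guard for data that arrives later.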