jaredleekatzman / DeepSurv

DeepSurv is a deep learning approach to survival analysis.

TypeError: nesterov_momentum() got an unexpected keyword argument 'logger' #63

Open · xiao751 opened this issue 4 years ago

xiao751 commented 4 years ago

When I run this line, I get the error below. Why is that?

```python
metrics = model.train(train_data, n_epochs=n_epochs, logger=logger, update_fn=update_fn)
```

```
[INFO] Training CoxMLP

TypeError                                 Traceback (most recent call last)
<ipython-input-...> in <module>
     18
     19 # If you have validation data, you can add it as the second parameter to the function
---> 20 metrics = model.train(train_data, n_epochs=n_epochs, logger=logger, update_fn=update_fn)

D:\anaconda\lib\site-packages\deepsurv\deep_surv.py in train(self, train_data, valid_data, n_epochs, validation_frequency, patience, improvement_threshold, patience_increase, verbose, update_fn, **kwargs)
    366         reached, looks at validation improvement to increase patience or
    367         early stop.
--> 368         improvement_threshold: percentage of improvement needed to increase
    369         patience.
    370         patience_increase: multiplier to patience if threshold is reached.

D:\anaconda\lib\site-packages\deepsurv\deep_surv.py in _get_train_valid_fn(self, L1_reg, L2_reg, learning_rate, **kwargs)
    208         updates = update_fn(
    209             scaled_grads, self.params, **kwargs
--> 210         )
    211         else:
    212             updates = update_fn(

D:\anaconda\lib\site-packages\deepsurv\deep_surv.py in _get_loss_updates(self, L1_reg, L2_reg, update_fn, max_norm, deterministic, **kwargs)
    179     Returns Theano expressions for the network's loss function and parameter
    180     updates.
--> 181
    182     Parameters:
    183         L1_reg: float for L1 weight regularization coefficient.

TypeError: nesterov_momentum() got an unexpected keyword argument 'logger'
```
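A workaround that follows from the signatures in the traceback (a sketch, untested; it assumes `train()` only uses `logger` for optional progress output): keep `logger` out of the call, so it never enters `**kwargs` and is never forwarded to `nesterov_momentum()`:

```python
# Workaround sketch: omit logger=logger from the call. train() would otherwise
# forward it through **kwargs to lasagne's nesterov_momentum(), which rejects it.
metrics = model.train(train_data, n_epochs=n_epochs, update_fn=update_fn)
```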
jaredleekatzman commented 3 years ago

Are you still having this issue? It looks like the logger is being passed from the .train() function to _get_train_valid_fn(), which then passes it on to update_fn through **kwargs.
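For illustration, here is the forwarding chain in miniature (plain Python standing in for the DeepSurv and lasagne functions, not the actual library code):

```python
# Stand-in for lasagne.updates.nesterov_momentum: accepts only the keywords
# it declares, so any stray keyword raises a TypeError.
def update_fn(grads, params, learning_rate=1e-4):
    return {}

def _get_train_valid_fn(learning_rate, **kwargs):
    # 'logger' is still inside **kwargs here and is forwarded blindly:
    return update_fn(None, None, learning_rate=learning_rate, **kwargs)

def train(data, n_epochs, **kwargs):
    # 'logger' is not a named parameter of train(), so it lands in **kwargs...
    return _get_train_valid_fn(learning_rate=1e-4, **kwargs)

train(None, n_epochs=500, logger=object())
# TypeError: update_fn() got an unexpected keyword argument 'logger'
```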

Adding a `logger=None` parameter to the function signature of `_get_train_valid_fn` might fix the problem.

For example:

```python
_get_train_valid_fn(self, L1_reg, L2_reg, learning_rate, **kwargs)
```

would become:

```python
_get_train_valid_fn(self, L1_reg, L2_reg, learning_rate, logger=None, **kwargs)
```
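In deep_surv.py that would look roughly like this (a minimal sketch with the method body elided, not the full source):

```python
# deepsurv/deep_surv.py (sketch): give logger its own parameter so it is
# captured here instead of landing in **kwargs.
def _get_train_valid_fn(self, L1_reg, L2_reg, learning_rate,
                        logger=None, **kwargs):
    # **kwargs no longer carries 'logger', so the downstream call
    #     updates = update_fn(scaled_grads, self.params, **kwargs)
    # only receives keywords that lasagne.updates.nesterov_momentum() accepts.
    ...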