jaredleekatzman / DeepSurv

DeepSurv is a deep learning approach to survival analysis.
MIT License
584 stars, 173 forks

Reproducibility of model training #77

Closed: zhxiaokang closed this issue 2 years ago

zhxiaokang commented 2 years ago

I tried using random.seed() to fix the randomness of model training, as follows:

# Create an instance of DeepSurv using the hyperparams defined above
model = deepsurv.DeepSurv(**hyperparams)

experiment_name = 'test_deepsurv'
logdir = './logs/tensorboard/'
logger = TensorboardLogger(experiment_name, logdir=logdir)

# Now we train the model
update_fn = lasagne.updates.nesterov_momentum  # The type of optimizer to use.

import random
random.seed(1234)
metrics = model.train(train_data, valid_data, n_epochs=n_epochs, logger=logger, update_fn=update_fn)

# Plot the training / validation curves
viz.plot_log(metrics)

However, this did not work: every time I re-ran the code above, I still got very different models.
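One likely reason, shown here as a small standalone check that does not use DeepSurv itself: `random.seed()` seeds only Python's built-in `random` module, while NumPy keeps its own separate global generator. The helper `draws_after_python_seed` is a hypothetical name used only for illustration.

```python
import random

import numpy as np


def draws_after_python_seed(seed=1234):
    """Seed Python's built-in `random` module, then draw from both generators."""
    random.seed(seed)
    python_draws = [random.random() for _ in range(3)]
    numpy_draws = np.random.rand(3)  # NumPy's global generator is not touched by random.seed()
    return python_draws, numpy_draws


py1, np1 = draws_after_python_seed()
py2, np2 = draws_after_python_seed()

print(py1 == py2)                # True: the built-in generator was reset
print(np.array_equal(np1, np2))  # False (with overwhelming probability): NumPy's was not
```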

Is there a way to fix the training randomness and ensure reproducibility?
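Even with the right generator seeded, order matters: in the snippet above the seed is set after `deepsurv.DeepSurv(**hyperparams)` has already run, so any randomness used while constructing the model (for example, initial weights) is drawn before seeding. The toy sketch below assumes that both construction and training draw from NumPy's global generator; `build_toy_model`, `toy_train`, and `run` are hypothetical stand-ins, not DeepSurv's actual code.

```python
import numpy as np


def build_toy_model(n_in=4, n_hidden=3):
    """Hypothetical stand-in for model construction: random initial weights."""
    return np.random.normal(scale=0.1, size=(n_in, n_hidden))


def toy_train(weights, n_steps=5, keep_prob=0.5, shrink=0.1):
    """Hypothetical stand-in for training: each step draws a dropout-style mask."""
    w = weights.copy()
    for _ in range(n_steps):
        mask = np.random.rand(*w.shape) < keep_prob
        w = w * (1.0 - shrink * mask)  # toy update whose only purpose is to consume randomness
    return w


def run(seed_before_build, seed=0):
    if seed_before_build:
        np.random.seed(seed)
        model = build_toy_model()
    else:
        model = build_toy_model()
        np.random.seed(seed)  # seeding after construction, as in the original snippet
    return toy_train(model)


print(np.array_equal(run(True), run(True)))    # True
print(np.array_equal(run(False), run(False)))  # False (with overwhelming probability)
```

Seeding after construction fixes only the training-time draws; the initial weights still differ between runs, so the final models differ too.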

yudeng2015 commented 2 years ago

I tried np.random.RandomState(seed=1), and it seems to work.

zhxiaokang commented 2 years ago

Hi @yudeng2015, thank you for the suggestion. Unfortunately, I tried it several times, and np.random.RandomState(seed=1) does not make the model stable...
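This matches how NumPy behaves: `np.random.RandomState(seed=1)` builds a new, independent generator object and leaves the global `np.random` state untouched, whereas `np.random.seed(1)` resets the global generator. Lasagne, which DeepSurv builds on, documents that `lasagne.random.get_rng()` returns `numpy.random` unless `lasagne.random.set_rng()` has been called, so only the global seed (or handing a generator to `set_rng`) would be expected to reach it. The check below uses NumPy only; `global_draws_after` is a hypothetical helper.

```python
import numpy as np


def global_draws_after(setup):
    """Run `setup`, then draw three numbers from NumPy's global generator."""
    setup()
    return np.random.rand(3)


# Creating a RandomState builds a separate generator; the global one is not reseeded.
a = global_draws_after(lambda: np.random.RandomState(seed=1))
b = global_draws_after(lambda: np.random.RandomState(seed=1))

# np.random.seed resets the global generator used by np.random.* calls.
c = global_draws_after(lambda: np.random.seed(1))
d = global_draws_after(lambda: np.random.seed(1))

# The private generator is itself reproducible, but only for draws made from it.
e = np.random.RandomState(seed=1).rand(3)
f = np.random.RandomState(seed=1).rand(3)

print(np.array_equal(a, b), np.array_equal(c, d), np.array_equal(e, f))
```

Expect `False True True`, where the first `False` holds with overwhelming probability rather than by construction.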

yudeng2015 commented 2 years ago

Hi @zhxiaokang, I followed this post https://github.com/keras-team/keras/issues/2743 and tried the following code:


seed_value = 0

# 1. Hide all GPUs (run on CPU only) and set the `PYTHONHASHSEED` environment variable at a fixed value
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ['PYTHONHASHSEED'] = str(seed_value)

# 2. Set the `python` built-in pseudo-random generator at a fixed value
import random
random.seed(seed_value)

# 3. Set the `numpy` pseudo-random generator at a fixed value
import numpy as np
np.random.seed(seed_value)

# 4. Set the `tensorflow` pseudo-random generator at a fixed value
import tensorflow as tf
tf.random.set_seed(seed_value)            # TensorFlow 2.x API
tf.compat.v1.set_random_seed(seed_value)  # TF1-style API, available via tf.compat.v1

# 5. Configure a new global `tensorflow` session
from keras import backend as K  # only needed by the commented-out TensorFlow 1.x lines below
# TensorFlow 1.x:
#session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
#sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
#K.set_session(sess)
# TensorFlow 2.x (via tf.compat.v1):
session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
tf.compat.v1.keras.backend.set_session(sess)

I ran the code above first, before importing any other libraries. Does this fix the issue?
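A note on step 1: Python reads `PYTHONHASHSEED` only when an interpreter starts, so assigning it through `os.environ` does not change string hashing in the process that is already running; it applies to Python processes launched afterwards. The sketch below checks this with fresh child interpreters; `child_hash` is a hypothetical helper, and it assumes Python 3.7+ for `capture_output`.

```python
import os
import subprocess
import sys


def child_hash(text, hash_seed):
    """Return hash(text) computed in a fresh interpreter started with PYTHONHASHSEED=hash_seed."""
    env = dict(os.environ, PYTHONHASHSEED=str(hash_seed))
    result = subprocess.run(
        [sys.executable, "-c", f"print(hash({text!r}))"],
        env=env, capture_output=True, text=True, check=True,
    )
    return int(result.stdout)


# Two fresh interpreters started with the same PYTHONHASHSEED agree on string hashes.
print(child_hash("deepsurv", 0) == child_hash("deepsurv", 0))  # True
```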

zhxiaokang commented 2 years ago

Hi @yudeng2015, thank you! That works!