The source code for Dr. Agent: Clinical Predictive Model via Mimicked Second Opinions
We provide the trained weights in ./saved_weights/. You can reproduce the performance reported in our paper by loading these weights into the model with the following code:
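import torch
# Replace TASK_TO_TEST with the weight file of the task you want to evaluate.
# `model` and `optimizer` must already be built with the matching architecture.
checkpoint = torch.load('./saved_weights/TASK_TO_TEST')
save_chunk = checkpoint['chunk']
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])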
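The loading snippet expects a checkpoint dictionary with the keys 'net', 'optimizer', and 'chunk'. As a minimal sketch of that format, the example below saves and reloads a checkpoint with the same three keys. It assumes a tiny placeholder nn.GRU in place of a real Dr. Agent model (the actual model classes live in ./model/) and a hypothetical chunk value of 0, since the meaning of 'chunk' is defined by the training scripts. It also passes map_location='cpu', which lets weights saved on a GPU be loaded on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Placeholder model: a tiny GRU standing in for a Dr. Agent model from ./model/.
model = nn.GRU(input_size=4, hidden_size=8, batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
save_chunk = 0  # hypothetical value; what 'chunk' records is defined by the training scripts

# Save the three keys that the loading snippet reads back.
ckpt_path = os.path.join(tempfile.mkdtemp(), 'example_checkpoint')
torch.save(
    {'net': model.state_dict(), 'optimizer': optimizer.state_dict(), 'chunk': save_chunk},
    ckpt_path,
)

# Reload into freshly built objects; map_location='cpu' remaps any GPU tensors to the CPU.
restored_model = nn.GRU(input_size=4, hidden_size=8, batch_first=True)
restored_optimizer = torch.optim.Adam(restored_model.parameters(), lr=1e-3)
checkpoint = torch.load(ckpt_path, map_location='cpu')
restored_model.load_state_dict(checkpoint['net'])
restored_optimizer.load_state_dict(checkpoint['optimizer'])
restored_chunk = checkpoint['chunk']
restored_model.eval()  # switch to evaluation mode before computing test metrics
```

Note that load_state_dict only succeeds when the model is constructed with the same architecture and hyper-parameters that were used when the weights were saved.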
We do not provide the MIMIC-III data itself. You must acquire the data yourself from https://mimic.physionet.org/; specifically, download the CSV files. To run the MIMIC-III benchmark tasks, first build the benchmark datasets by following https://github.com/YerevaNN/mimic3-benchmarks/.
After building the benchmark datasets, there will be a directory data/{task} for each created benchmark task. Then run extract_demo.py to extract demographics from the dataset (change TASK to the specific task).
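The demographics step links each benchmark time series to per-stay patient information. The sketch below is a hypothetical illustration of that idea, not the contents of extract_demo.py. It assumes task time-series files are named like "<SUBJECT_ID>_episode<N>_timeseries.csv" and that a matching summary file data/root/<SUBJECT_ID>/episode<N>.csv contains columns such as Age, Gender, and Ethnicity; the helper names, the chosen columns, and the sample row are all made up for illustration.

```python
import csv
import io
import re

# Hypothetical naming convention for benchmark time-series files.
TIMESERIES_NAME = re.compile(r'(\d+)_episode(\d+)_timeseries\.csv$')

def episode_summary_path(timeseries_name, root='data/root'):
    """Map a task time-series file name to its (assumed) per-episode summary file."""
    match = TIMESERIES_NAME.match(timeseries_name)
    if match is None:
        raise ValueError(f'unexpected file name: {timeseries_name}')
    subject_id, episode = match.groups()
    return f'{root}/{subject_id}/episode{episode}.csv'

# Hypothetical choice of demographic columns; adjust to what your summary files contain.
DEMO_FIELDS = ('Age', 'Gender', 'Ethnicity')

def extract_demographics(csv_text, fields=DEMO_FIELDS):
    """Return the requested fields of the first data row as floats."""
    row = next(csv.DictReader(io.StringIO(csv_text)))
    return [float(row[name]) for name in fields]

# Toy summary row for illustration only (not real MIMIC-III data).
sample = 'Icustay,Age,Gender,Ethnicity,Height,Weight\n200001,64.2,1,2,,80.0\n'
print(episode_summary_path('10011_episode1_timeseries.csv'))  # data/root/10011/episode1.csv
print(extract_demographics(sample))  # [64.2, 1.0, 2.0]
```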
You can train Dr. Agent on different tasks by running the corresponding training script. The minimum input required is the dataset directory and the directory where the model weights will be saved:
$ python train_decomp.py --data_path='./decompensation/data/' --save_path='./saved_weights/'
You can also specify the batch size (--batch_size <integer>), learning rate (--lr <float>), and number of epochs (--epochs <integer>).
Additional hyper-parameters, such as the dimension of the RNN or whether to use an LSTM or a GRU, can also be specified. Detailed information is available via:
$ python train_decomp.py --help
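The command-line interface above follows the usual argparse conventions. The sketch below is a hypothetical reconstruction covering only the flags documented here (--data_path, --save_path, --batch_size, --lr, --epochs); it is not the repository's actual parser, and the default values are placeholders, so check --help for the real defaults and the full list of options.

```python
import argparse

def build_parser():
    """Hypothetical sketch of the documented training flags (not the repo's actual parser)."""
    parser = argparse.ArgumentParser(description='Train Dr. Agent (sketch)')
    parser.add_argument('--data_path', type=str, required=True,
                        help='benchmark task directory, e.g. ./decompensation/data/')
    parser.add_argument('--save_path', type=str, required=True,
                        help='directory where model weights are saved')
    # Placeholder defaults for illustration only; the real values are shown by --help.
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--epochs', type=int, default=50)
    return parser

# argparse accepts both the '--flag=value' and the '--flag value' forms.
args = build_parser().parse_args([
    '--data_path=./decompensation/data/',
    '--save_path=./saved_weights/',
    '--lr', '0.0005',
    '--epochs', '20',
])
print(args.data_path, args.batch_size, args.lr, args.epochs)
```

Because --data_path and --save_path are marked required in this sketch, omitting either one makes argparse print a usage message and exit, mirroring the "minimum input" described above.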
When training is complete, the script outputs the performance of Dr. Agent on the test dataset.
The minimal inputs to the Dr. Agent model should contain:
You can directly use the model structures in the ./model/ directory for different purposes:
model_decomp.py: binary classification with outputs at each timestep
model_los.py: multi-class prediction
model_mortality.py: binary classification with output at the last timestep
model_phenotyping.py: multi-label prediction
You can also modify the structure for your specific tasks.
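The model files differ mainly in what their output layer predicts. The plain-Python sketch below illustrates these output conventions on made-up logits for a single patient: per-timestep binary probabilities, a single binary probability from the last timestep, a multi-class softmax, and independent multi-label sigmoids. It is not code from ./model/ (which uses PyTorch); the 0.5 threshold and the argmax decision rule are assumptions chosen for illustration.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Toy logits for one patient (made-up numbers).
per_step_logits = [-2.0, 0.0, 3.0]      # one logit per timestep
class_logits = [0.1, 2.0, -1.0, 0.5]    # one logit per class
label_logits = [1.5, -0.5, 0.0]         # one logit per label

# Binary classification at every timestep (decompensation-style).
step_probs = [sigmoid(z) for z in per_step_logits]
# Binary classification from the last timestep only (mortality-style).
last_step_prob = sigmoid(per_step_logits[-1])
# Multi-class: probabilities sum to 1 and exactly one class is chosen.
class_probs = softmax(class_logits)
predicted_class = max(range(len(class_probs)), key=class_probs.__getitem__)
# Multi-label: independent sigmoids, so any subset of labels can be positive.
predicted_labels = [sigmoid(z) >= 0.5 for z in label_logits]

print(predicted_class, predicted_labels)  # 1 [True, False, True]
```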
Junyi Gao, Cao Xiao, Lucas M Glass, Jimeng Sun,
Dr. Agent: Clinical predictive model via mimicked second opinions,
Journal of the American Medical Informatics Association, ocaa074, https://doi.org/10.1093/jamia/ocaa074