@turambar @bpark738
I was attempting to write the lab today, and in order to do that I like to describe each step.
I struggled to grasp the dataset and the transformation; it would help if one of you did a pass through the example and commented almost everything for me in a gist.
Here is what I mean
My comments/questions will start with //TH so you can grep for them
public static void main(String[] args) throws IOException, InterruptedException {
// STEP 0: Flags controlling which data
// 0 for removing Time and Elapsed columns; 1 for removing Time; 2 for removing Elapsed
//TH what does the data look like before we do this, what is Time and Elapsed, maybe data sample
int remove = 0;
int numLabelClasses = 2;
boolean resampled = false; // If true use resampled data
//TH I get iterators, no need to explain these
DataSetIterator trainData;
DataSetIterator validData;
DataSetIterator testData;
if(resampled){
//TH Grandpa needs some help here, resampled in what way, I assume simplify by less timesteps ?
NB_INPUTS-=2;
featuresDir = new File(baseDir, "resampled");
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(1, ",");
trainFeatures.initialize( new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
trainLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
// Load validation data
//TH I could review numbered file input split and all that, but a comment from either of you might be quicker
SequenceRecordReader validFeatures = new CSVSequenceRecordReader(1, ",");
validFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES - 1));
SequenceRecordReader validLabels = new CSVSequenceRecordReader();
validLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES - 1));
validData = new SequenceRecordReaderDataSetIterator(validFeatures, validLabels,
BATCH_SIZE, numLabelClasses, false,SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
// Load test data
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(1, ",");
testFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));
SequenceRecordReader testLabels = new CSVSequenceRecordReader();
testLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));
testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels,
BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
}
else{
//TH I get the schema, no need to dive to deep for me on this.
Schema schema = new SequenceSchema.Builder()
.addColumnsDouble("Time","Elapsed","ALP").addColumnCategorical("ALPMissing")
.addColumnDouble("ALT").addColumnCategorical("ALTMissing").addColumnDouble("AST")
.addColumnCategorical("ASTMissing").addColumnDouble("Age").addColumnCategorical("AgeMissing")
.addColumnDouble("Albumin").addColumnCategorical("AlbuminMissing").addColumnDouble("BUN")
.addColumnCategorical("BUNMissing").addColumnDouble("Bilirubin").addColumnCategorical("BilirubinMissing")
.addColumnDouble("Cholesterol").addColumnCategorical("CholesterolMissing").addColumnDouble("Creatinine")
.addColumnCategorical("CreatinineMissing").addColumnDouble("DiasABP").addColumnCategorical("DiasABPMissing")
.addColumnDouble("FiO2").addColumnCategorical("FiO2Missing").addColumnDouble("GCS")
.addColumnCategorical("GCSMissing").addColumnCategorical("Gender0").addColumnCategorical("Gender1")
.addColumnDouble("Glucose").addColumnCategorical("GlucoseMissing").addColumnDouble("HCO3")
.addColumnCategorical("HCO3Missing").addColumnDouble("HCT").addColumnCategorical("HCTMissing")
.addColumnDouble("HR").addColumnCategorical("HRMissing").addColumnDouble("Height")
.addColumnCategorical("HeightMissing").addColumnCategorical("ICUType1").addColumnCategorical("ICUType2")
.addColumnCategorical("ICUType3").addColumnCategorical("ICUType4").addColumnDouble("K")
.addColumnCategorical("KMissing").addColumnDouble("Lactate").addColumnCategorical("LactateMissing")
.addColumnDouble("MAP").addColumnCategorical("MAPMissing").addColumnDouble("MechVent")
.addColumnCategorical("MechVentMissing").addColumnDouble("Mg").addColumnCategorical("MgMissing")
.addColumnDouble("NIDiasABP").addColumnCategorical("NIDiasABPMissing").addColumnDouble("NIMAP")
.addColumnCategorical("NIMAPMissing").addColumnDouble("NISysABP").addColumnCategorical("NISysABPMissing")
.addColumnDouble("Na").addColumnCategorical("NaMissing").addColumnDouble("PaCO2")
.addColumnCategorical("PaCO2Missing").addColumnDouble("PaO2").addColumnCategorical("PaO2Missing")
.addColumnDouble("Platelets").addColumnCategorical("PlateletsMissing").addColumnDouble("RespRate")
.addColumnCategorical("RespRateMissing").addColumnDouble("SaO2").addColumnCategorical("SaO2Missing")
.addColumnDouble("SysABP").addColumnCategorical("SysABPMissing").addColumnDouble("Temp")
.addColumnCategorical("TempMissing").addColumnDouble("TroponinI").addColumnCategorical("TroponinIMissing")
.addColumnDouble("TroponinT").addColumnCategorical("TroponinTMissing").addColumnDouble("Urine")
.addColumnCategorical("UrineMissing").addColumnDouble("WBC").addColumnCategorical("WBCMissing")
.addColumnDouble("Weight").addColumnCategorical("WeightMissing").addColumnDouble("pH")
.addColumnCategorical("pHMissing").build();
TransformProcess transformProcess;
//TH so we are removing some or more depending on a variable that is set, please comment that for me, or split into separate classes, I know you hate redundant code, I sort of love it, keep it simple for the masses I say. Either way, do what is efficient and workable
if(remove == 0){
transformProcess = new TransformProcess.Builder(schema).removeColumns("Time", "Elapsed").build();
NB_INPUTS-=2;
}
else if(remove == 1){
transformProcess = new TransformProcess.Builder(schema).removeColumns("Time").build();
NB_INPUTS-=1;
}
else if(remove == 2){
transformProcess = new TransformProcess.Builder(schema).removeColumns("Elapsed").build();
NB_INPUTS-=1;
}
else{
transformProcess = new TransformProcess.Builder(schema).build();
}
// Load training data
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(1, ",");
trainFeatures.initialize( new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
trainLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
TransformProcessSequenceRecordReader trainRemovedFeatures = new TransformProcessSequenceRecordReader(trainFeatures, transformProcess);
trainData = new SequenceRecordReaderDataSetIterator(trainRemovedFeatures, trainLabels,
BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
// Load validation data
SequenceRecordReader validFeatures = new CSVSequenceRecordReader(1, ",");
validFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES - 1));
SequenceRecordReader validLabels = new CSVSequenceRecordReader();
validLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES , NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES - 1));
TransformProcessSequenceRecordReader validRemovedFeatures = new TransformProcessSequenceRecordReader(validFeatures, transformProcess);
validData = new SequenceRecordReaderDataSetIterator(validRemovedFeatures, validLabels,
BATCH_SIZE, numLabelClasses, false,SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
// Load test data
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(1, ",");
testFeatures.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));
SequenceRecordReader testLabels = new CSVSequenceRecordReader();
testLabels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", NB_TRAIN_EXAMPLES+ NB_VALID_EXAMPLES, NB_TRAIN_EXAMPLES + NB_VALID_EXAMPLES + NB_TEST_EXAMPLES - 1));
TransformProcessSequenceRecordReader testRemovedFeatures = new TransformProcessSequenceRecordReader(testFeatures, transformProcess);
testData = new SequenceRecordReaderDataSetIterator(testRemovedFeatures, testLabels,
BATCH_SIZE, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
}
// STEP 1: ETL/vectorization
// STEP 2: Model configuration and initialization
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(RANDOM_SEED)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.learningRate(LEARNING_RATE)
.weightInit(WeightInit.XAVIER)
.updater(Updater.ADAM)
.graphBuilder()
.addInputs("trainFeatures")
.setOutputs("predictMortality")
.addLayer("L1", new GravesLSTM.Builder()
.nIn(NB_INPUTS)
.nOut(lstmLayerSize)
.activation(Activation.TANH)
.build(),
"trainFeatures")
.addLayer("predictMortality", new RnnOutputLayer.Builder(LossFunctions.LossFunction.XENT)
.activation(Activation.SOFTMAX)
.weightInit(WeightInit.XAVIER)
.nIn(lstmLayerSize).nOut(numLabelClasses).build(),"L1")
.pretrain(false).backprop(true)
.build();
// STEP 3 Performance monitoring
ComputationGraph model = new ComputationGraph(conf);
model.init();
model.setListeners(new ScoreIterationListener(10));
// STEP 4 Model training
for( int i=0; i<NB_EPOCHS; i++ ){
model.fit(trainData); // implicit inner loop over minibatches
// loop over batches in training data to compute training AUC
//TH would like to define AUC, and ROC and maybe point to docs
//TH not exactly the correct place for this, but they will want to discuss what the output is, so step back from the output layer which I assume is softmax or what is needed for classification, what is the output at each neuron, like one layer back? It gets a collection of sequences in, emits a collection of sequences?
//TH Thanks
ROC roc = new ROC(100);
trainData.reset();
while(trainData.hasNext()){
DataSet batch = trainData.next();
INDArray[] output = model.output(batch.getFeatures());
roc.evalTimeSeries(batch.getLabels(), output[0]);
}
log.info("EPOCH " + i + " TRAIN AUC: " + roc.calculateAUC());
roc = new ROC(100);
while (validData.hasNext()) {
DataSet batch = validData.next();
INDArray[] output = model.output(batch.getFeatures());
roc.evalTimeSeries(batch.getLabels(), output[0]);
}
log.info("EPOCH " + i + " VALID AUC: " + roc.calculateAUC());
trainData.reset();
validData.reset();
}
ROC roc = new ROC(100);
while (testData.hasNext()) {
DataSet batch = testData.next();
INDArray[] output = model.output(batch.getFeatures());
roc.evalTimeSeries(batch.getLabels(), output[0]);
}
log.info("***** Test Evaluation *****");
log.info("{}", roc.calculateAUC());
}
}
@turambar @bpark738 I was attempting to write the lab today, and in order to do that I like to describe each step.
I struggled to grasp the dataset and the transformation; it would help if one of you did a pass through the example and commented almost everything for me in a gist.
Here is what I mean
My comments/questions will start with //TH so you can grep for them