What is the purpose of this PR?
Adds functionality to automatically run inference on lists of datasets, adds a script to generate test set predictions, and fixes a bug in ModelBuilder.dset_marker_filter.
How did you implement your changes
Added a function ModelBuilder.predict_data_list that takes a list of tf.data.Dataset objects, predicts Nimbus scores, computes per-cell Nimbus scores together with the dataset, FOV, cell ID, cell type, marker and silver standard labels, and saves the result as a .csv file.
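The per-cell aggregation described above can be sketched roughly as follows. This is an illustration only, not the code in ModelBuilder.predict_data_list: the helper names per_cell_scores and write_rows, the column names, the example values, and the choice of averaging pixel predictions inside each cell of an instance mask are all assumptions, and plain Python lists stand in for the real tensors.

```python
import csv
import io
from collections import defaultdict


def per_cell_scores(prediction, instance_mask):
    """Average pixel-wise predictions per cell ID (hypothetical helper).

    Assumes `instance_mask` holds an integer cell ID per pixel, with 0 as background.
    """
    sums = defaultdict(float)
    counts = defaultdict(int)
    for pred_row, mask_row in zip(prediction, instance_mask):
        for score, cell_id in zip(pred_row, mask_row):
            if cell_id != 0:  # skip background pixels
                sums[cell_id] += score
                counts[cell_id] += 1
    return {cell_id: sums[cell_id] / counts[cell_id] for cell_id in sums}


def write_rows(handle, dataset, fov, marker, scores):
    """Write one CSV row per cell; the column names are illustrative."""
    writer = csv.writer(handle)
    writer.writerow(["dataset", "fov", "marker", "cell_id", "nimbus_score"])
    for cell_id in sorted(scores):
        writer.writerow([dataset, fov, marker, cell_id, scores[cell_id]])


# Toy 2x3 prediction map and matching instance mask (made-up values).
prediction = [[0.25, 0.75, 0.9], [0.5, 1.0, 0.5]]
instance_mask = [[1, 1, 0], [1, 2, 2]]
scores = per_cell_scores(prediction, instance_mask)

buffer = io.StringIO()  # stands in for an open .csv file
write_rows(buffer, "dataset_a", "fov0", "CD4", scores)
```

The real function additionally records cell type and silver standard labels per row; those columns are left out here to keep the sketch short.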
Added a script hyperparameter_search.py that uses the above function to compute validation predictions, calculates F1 scores for different pos/neg thresholds (individually for every marker in every dataset), and saves the optimal thresholds. It then runs inference on the test set and uses these thresholds to assign the pos/neg class.
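The threshold search can be pictured with a minimal sketch like the one below. It is not the contents of hyperparameter_search.py: the function names, the record layout, the candidate thresholds, the dataset name, and the convention that a score at or above the threshold counts as positive are assumptions made for illustration.

```python
def f1_score(labels, predictions):
    """F1 for binary labels and binary predictions (1 = positive, 0 = negative)."""
    tp = sum(1 for y, p in zip(labels, predictions) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, predictions) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, predictions) if y == 1 and p == 0)
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator else 0.0


def best_threshold(labels, scores, thresholds):
    """Return (threshold, f1) with the highest F1; ties keep the lowest threshold."""
    best = (thresholds[0], -1.0)
    for threshold in thresholds:
        # Assumption: a score at or above the threshold is called positive.
        predictions = [1 if s >= threshold else 0 for s in scores]
        f1 = f1_score(labels, predictions)
        if f1 > best[1]:
            best = (threshold, f1)
    return best


# Hypothetical validation records: (dataset, marker, silver standard label, score).
validation = [
    ("dataset_a", "CD4", 1, 0.9),
    ("dataset_a", "CD4", 1, 0.6),
    ("dataset_a", "CD4", 0, 0.4),
    ("dataset_a", "CD4", 0, 0.2),
    ("dataset_a", "CD8", 1, 0.8),
    ("dataset_a", "CD8", 0, 0.7),
    ("dataset_a", "CD8", 0, 0.1),
]

# Group labels and scores so every (dataset, marker) pair gets its own threshold.
grouped = {}
for dataset, marker, label, score in validation:
    labels, scores = grouped.setdefault((dataset, marker), ([], []))
    labels.append(label)
    scores.append(score)

candidate_thresholds = [0.25, 0.5, 0.75]
optimal = {
    key: best_threshold(labels, scores, candidate_thresholds)
    for key, (labels, scores) in grouped.items()
}
```

Searching per (dataset, marker) pair rather than globally lets each marker get a cutoff suited to its own score distribution, which matches the "individually for every marker in every dataset" description.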
Fixed a bug in ModelBuilder.dset_marker_filter. This function filters the dataset and drops samples with a specific dataset and marker combination (to exclude the two channels in the decidua dataset that have incorrect silver standard labels). The bug was that the filter predicate compared a byte string b'CD4' inside a tensor with a regular string 'CD4', which always evaluated to False, so no samples were ever excluded. The tests did not catch this because the test function name lacked the test prefix, so pytest never executed it. Both mistakes are now fixed.
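The same bytes-versus-str mismatch can be reproduced in plain Python, without TensorFlow. The sketch below is not the actual ModelBuilder.dset_marker_filter code: the predicate names and the decode-based fix are assumptions chosen to illustrate the failure mode and one possible remedy.

```python
def keep_sample_buggy(marker, excluded_marker="CD4"):
    """Buggy predicate: bytes never equal str, so every sample is kept."""
    return marker != excluded_marker


def keep_sample_fixed(marker, excluded_marker="CD4"):
    """Decode bytes before comparing so that both sides are str."""
    if isinstance(marker, bytes):
        marker = marker.decode("utf-8")
    return marker != excluded_marker


# Marker names as bytes, the form string tensors take once converted to Python values.
markers = [b"CD4", b"CD8"]
kept_buggy = [m for m in markers if keep_sample_buggy(m)]
kept_fixed = [m for m in markers if keep_sample_fixed(m)]


# With its default settings pytest only collects functions whose names start
# with `test`, so a name such as `check_marker_filter` would be skipped silently.
def test_marker_filter_drops_excluded_marker():
    assert kept_fixed == [b"CD8"]
```

Here kept_buggy still contains b"CD4", while kept_fixed drops it. Comparing against a bytes literal instead of decoding would work equally well.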
Remaining issues
None