Data normalization:
[x] Histogram normalization across the whole dataset. MVP implementation for this task. The simplest way to normalize the data is to calculate the distribution of each marker across the entire training dataset, then rescale every image using those per-marker statistics.
[x] Image-specific normalization. Investigate whether each image can be normalized without knowing the distribution of the whole dataset.
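Below is a minimal NumPy sketch of the two normalization items above. It rests on a few assumptions. "Histogram normalization" is read here as dividing each marker by a high percentile of its intensity distribution and clipping to [0, 1]. The 99.9th percentile is an assumed choice. The function names are hypothetical.

```python
import numpy as np


def dataset_percentiles(images, q=99.9):
    """Per-marker q-th percentile pooled over every pixel of every training image.

    images: iterable of arrays shaped (H, W, C) that share the same channel order.
    """
    # Flatten each image to (pixels, channels) and pool all training pixels.
    pixels = np.concatenate([img.reshape(-1, img.shape[-1]) for img in images], axis=0)
    return np.percentile(pixels, q, axis=0)  # one scale value per marker, shape (C,)


def normalize(img, scale, eps=1e-8):
    """Divide each channel by its scale and clip the result to [0, 1]."""
    return np.clip(img / (scale + eps), 0.0, 1.0)


def normalize_per_image(img, q=99.9):
    """Image-specific variant: the per-marker scale comes from this image alone."""
    scale = np.percentile(img.reshape(-1, img.shape[-1]), q, axis=0)
    return normalize(img, scale)


# Usage with synthetic data standing in for the training images.
rng = np.random.default_rng(0)
train_images = [rng.gamma(2.0, 1.0, size=(64, 64, 3)) for _ in range(4)]
scale = dataset_percentiles(train_images)
normalized = [normalize(img, scale) for img in train_images]
```

Pooling every pixel in memory only works for small datasets. A streaming alternative would accumulate a fixed-bin histogram per marker (e.g. with `np.histogram`) and read the percentile off the cumulative counts.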
Dataset creation:
[x] MIBI segmentation data loader. MVP implementation for this task. This would create a segmentation-based scheme for data loading.
[x] Plot utils to visualize positive/negative/unspecific (1/0/-1) cells on top of a marker channel
[ ] Multi-channel segmentation data loader. This would allow the model to be trained on multiple channels and labels simultaneously
[x] Fixed channels to aid prediction. This would allow the user to specify a list of channels that would always be included in a fixed order as model inputs, such as membrane/nucleus channels, to give cell morphology or co-expression information
[x] Remove ambiguous cells from training data. Some cells likely have borderline expression for specific markers, and we should consider adding an option to remove them from the training data completely. Removing hard examples will by definition improve performance on the filtered data, so we should evaluate whether it improves performance on the rest of the dataset, independent of the gain from excluding these tricky cells.
[ ] Normalize data once and store it on disk (or normalize in place) to avoid the compute overhead of normalizing it each time we generate a new TFRecord dataset.
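A segmentation-style loader needs to turn per-cell labels (positive/negative/unspecific, 1/0/-1) into a pixel-wise target. The sketch below does that and also shows how ambiguous cells could be dropped from training. It assumes an integer instance mask with 0 as background and a dict of per-cell labels. Unspecific and flagged-ambiguous cells are removed from the loss through a mask rather than deleted from the image. Treating background as a negative is also an assumption. The function and argument names are hypothetical.

```python
import numpy as np

UNSPECIFIC = -1  # cell label meaning "unspecific", per the 1/0/-1 convention


def cell_labels_to_target(instance_mask, cell_labels, ambiguous_ids=()):
    """Build a pixel-wise target and loss mask from per-cell marker labels.

    instance_mask: (H, W) int array, 0 = background, k > 0 = cell k.
    cell_labels: dict {cell_id: 1 (positive), 0 (negative) or -1 (unspecific)}.
    ambiguous_ids: cells flagged as borderline; they are removed from training.
    Returns (target, loss_mask), both float32 arrays of shape (H, W).
    """
    target = np.zeros(instance_mask.shape, dtype=np.float32)
    # Background stays in the loss as a negative here (an assumption).
    loss_mask = np.ones(instance_mask.shape, dtype=np.float32)
    ambiguous = set(ambiguous_ids)
    for cell_id, label in cell_labels.items():
        pixels = instance_mask == cell_id
        if label == UNSPECIFIC or cell_id in ambiguous:
            loss_mask[pixels] = 0.0  # excluded from the loss entirely
        else:
            target[pixels] = float(label)
    return target, loss_mask
```

Because these cells are masked rather than relabelled, the same images can be evaluated with and without the ambiguous cells. That is the comparison the item above calls for.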
Image augmentation
[x] Basic image augmentation for segmentation model. MVP implementation for this task. This would augment the input image channel separately from the input masks.
[x] Multi-channel augmentation. This would handle multiple input channels of image data.
[x] Run augmentations on the GPU to speed up training (e.g. via KerasCV or TensorLayer)
[ ] Investigate effect of augmentation on val loss across dataset sizes, in line with this augmentation paper
[ ] Determine if MixUp provides an additional performance boost over current augmentations
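The key constraint in the basic and multi-channel augmentation items is that geometric transforms must hit the image and its masks identically, while intensity changes touch only the image. The CPU/NumPy sketch below illustrates that split. It is not the GPU pipeline. The transform set (90-degree rotations, horizontal flip, per-channel gain in [0.8, 1.2]) and the function name are assumptions chosen for illustration.

```python
import numpy as np


def augment_pair(image, mask, rng):
    """Apply the same random rot90/flip to image and mask; jitter intensity on the image only.

    image: (H, W, C) float array; mask: (H, W) or (H, W, K) array.
    """
    k = int(rng.integers(4))  # number of 90-degree rotations, shared
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    if rng.random() < 0.5:  # horizontal flip, shared
        image = image[:, ::-1]
        mask = mask[:, ::-1]
    # Intensity augmentation: an independent gain per input channel, image only.
    gain = rng.uniform(0.8, 1.2, size=image.shape[-1])
    image = image * gain
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)
```

Non-square crops change shape under odd rotations, so batching assumes square tiles.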
Image generator
[x] Basic image generator for segmentation model. MVP implementation for this task. This would load the TFRecord, pass the data to the augmenter, assemble batches of the appropriate size, and then pass them to the model.
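The generator's role (load, augment, batch, hand off) can be sketched as a plain Python generator. This sketch assumes the TFRecord has already been decoded into a list of `(image, target)` NumPy pairs, since decoding depends on the record schema. It also assumes the augmenter is any callable with the signature `(image, target, rng)`. All names are hypothetical.

```python
import numpy as np


def batch_generator(samples, batch_size, augment=None, rng=None, shuffle=True):
    """Yield (images, targets) batches indefinitely, reshuffling each epoch.

    samples: list of (image, target) pairs, assumed already decoded from the
    TFRecord. augment: optional callable (image, target, rng) -> (image, target).
    """
    if batch_size > len(samples):
        raise ValueError("batch_size is larger than the number of samples")
    rng = rng if rng is not None else np.random.default_rng()
    n = len(samples)
    while True:
        order = rng.permutation(n) if shuffle else np.arange(n)
        # Step over full batches only, so every batch has the same shape.
        for start in range(0, n - batch_size + 1, batch_size):
            pairs = [samples[i] for i in order[start:start + batch_size]]
            if augment is not None:
                pairs = [augment(x, y, rng) for x, y in pairs]
            images, targets = zip(*pairs)
            yield np.stack(images), np.stack(targets)
```

In a TensorFlow pipeline the same structure maps onto `tf.data.TFRecordDataset(...).map(...).batch(...).prefetch(...)`, which also overlaps loading with training.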
Model architecture
[x] Basic architecture for training. MVP implementation for this task. We can probably use something similar to what was in Dave's example notebook: a very simple CNN with a couple of layers.
[x] Port over popular baseline models. We should include a couple of the common models, such as ResNet50 and EfficientNetV2, as options for the backbone. We can then swap them in and out easily during training to evaluate their impact on performance.
[ ] Use BatchNorm layer in BEN scheme to cope with image-level normalization.
[ ] Evaluate different crop sizes to see if there is a performance advantage for smaller or larger images
[ ] Evaluate model robustness to changes in image resolution
[ ] Evaluate different model backbones to assess differences in performance
[x] Inverse exponential moving average class frequency weighting in the loss function
[x] Uniform sampling of positive and negative cells. We can experiment with this instead of, or together with, loss weighting. Should this be done per cell type or per marker?
Training data selection
[ ] Use active learning to identify the most informative images and include them preferentially.
[x] Remove tiles with 0 segmented cells
[x] Remove tiles with 0 positive cells
[x] Remove tiles whose number of positive cells is below the x-th percentile, where the percentile is determined on a per-cell, per-channel, or global basis.
[x] Use clean sample selection procedure from ProMix Naive to sort out noisy labels
[ ] Return ProMix predictions on a per-image basis for use in downstream tasks, model training, or a published dataset.
[ ] Use augmentation consistency training or ProMix's label-guessing SSL approach to incorporate noisy samples.
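The tile-removal items above (drop tiles with 0 segmented cells, 0 positive cells, or a positive count below a percentile) reduce to a simple filter. This sketch assumes each tile carries precomputed counts under the hypothetical keys `n_cells` and `n_positive`. It computes the percentile globally over the tiles that survive the zero-count filters; a per-channel or per-cell-type variant would group the tiles first.

```python
import numpy as np


def select_tiles(tiles, min_percentile=None):
    """Filter training tiles by their cell counts.

    tiles: list of dicts with hypothetical keys "n_cells" and "n_positive".
    min_percentile: if given, also drop tiles whose positive count falls below
    this percentile of the remaining tiles' positive counts.
    """
    # Drop tiles with no segmented cells or no positive cells.
    kept = [t for t in tiles if t["n_cells"] > 0 and t["n_positive"] > 0]
    if min_percentile is not None and kept:
        cutoff = np.percentile([t["n_positive"] for t in kept], min_percentile)
        kept = [t for t in kept if t["n_positive"] >= cutoff]
    return kept
```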
Model training
[ ] Training pipeline walkthrough notebook that shows the loaded data, the augmented data, and the output of the model after x iterations
[x] Add grokking training scheme: high regularization and cosine annealing learning rate scheduler
[x] Evaluate loss only on pixels within cells. We can experiment to see if restricting the loss to plausible regions results in a smoother manifold/better consistency of the trained model
[ ] Include an uncertainty score in addition to classification score. Giving the model the opportunity to express confidence in its own prediction can result in better calibration/accuracy
[x] Alternative loss functions. We can experiment to see if weighting rare cell types, weighting positive examples instead of negatives, or other schemes results in improved performance given the imbalance and noise in the ground truth labels. For example, focal loss or something similar.
[ ] Use object-centric loss function: calculate loss as the average of the loss of the segments instead of average of all pixels to overweight small cells and underweight big cells.
[ ] Determine relationship between dataset size and accuracy across cell types
[ ] Optimize training speed by using XLA compilation, mixed-precision training and distributed parallel training strategies
[ ] Add style vector (see Cellpose) for different datasets/markers
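The object-centric loss item (together with restricting the loss to pixels within cells) can be made concrete as follows. First average the per-pixel loss inside each cell, then average over cells, so a small cell counts as much as a large one. This sketch assumes a single marker channel, binary cross-entropy as the per-pixel loss, and an integer instance mask with 0 as background. The function names are hypothetical.

```python
import numpy as np


def binary_cross_entropy(prob, target, eps=1e-7):
    """Element-wise binary cross-entropy."""
    p = np.clip(prob, eps, 1 - eps)
    return -(target * np.log(p) + (1 - target) * np.log(1 - p))


def object_centric_loss(prob, target, instance_mask):
    """Mean over cells of the mean pixel loss inside each cell.

    Background pixels (instance id 0) are ignored, so the loss is restricted
    to pixels within cells, and every cell contributes equally regardless of size.
    """
    ids = instance_mask.ravel()
    losses = binary_cross_entropy(prob, target).ravel()
    inside = ids > 0
    if not inside.any():
        return 0.0
    ids, losses = ids[inside], losses[inside]
    # Sum and count the loss per cell id, then average within each cell.
    sums = np.bincount(ids, weights=losses)
    counts = np.bincount(ids)
    present = counts > 0
    return float(np.mean(sums[present] / counts[present]))
```

For example, with a 3-pixel cell predicted well and a 1-pixel cell predicted badly, a plain pixel mean weights the large cell 3:1. This version weights them 1:1.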
Model evaluation
[x] Evaluate pixel-level accuracy scores. MVP implementation for this task. We can report the pixel-level accuracy for individual channels to assess how the model is doing
[x] Evaluate cell-level accuracy. To evaluate cell-level accuracy, we'll need to integrate the scores of the individual pixels within each cell, then set a threshold for calling a cell positive. This threshold is a hyperparameter we can tune to improve performance.
[x] Estimate what fraction of errors are true errors and what fraction are due to mislabeled data
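The cell-level accuracy item (integrate pixel scores per cell, then threshold) could be sketched as below. It assumes "integrate" means taking the mean probability over the cell's pixels. Cells labelled -1 (unspecific) are skipped. The threshold is exposed as the tunable hyperparameter. Function names are hypothetical.

```python
import numpy as np


def cell_scores(prob, instance_mask):
    """Mean predicted probability per cell id (ids > 0)."""
    ids = instance_mask.ravel()
    sums = np.bincount(ids, weights=prob.ravel())
    counts = np.bincount(ids)
    cell_ids = np.nonzero(counts)[0]
    cell_ids = cell_ids[cell_ids > 0]  # drop background
    return dict(zip(cell_ids.tolist(), (sums[cell_ids] / counts[cell_ids]).tolist()))


def cell_accuracy(prob, instance_mask, cell_labels, threshold=0.5):
    """Fraction of labelled cells (label 0 or 1) whose thresholded mean score matches."""
    scores = cell_scores(prob, instance_mask)
    pairs = [(scores[c], y) for c, y in cell_labels.items() if y in (0, 1) and c in scores]
    if not pairs:
        raise ValueError("no labelled cells to evaluate")
    correct = sum((s >= threshold) == bool(y) for s, y in pairs)
    return correct / len(pairs)
```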
Model error visualization
[x] Split cell-level metrics by channel type and cell type (not jointly) and use a facet plot to plot metrics for all subsets. Some useful visualizations: a separate accuracy plot for each channel, aggregated over all cell types. A separate accuracy plot for each cell type, aggregated over all channels. A separate accuracy plot for each channel in each cell type.
[ ] Generate crops of informative cell assignments. 1) High confidence predictions that are correct. 2) High confidence predictions that are incorrect. 3) Low confidence predictions that are correct. 4) Low confidence predictions that are incorrect. Crops would include the input data, probability mask, and metadata.
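Selecting cells for the four crop groups above is mostly bookkeeping. This sketch assumes confidence is `max(p, 1 - p)` for a cell's mean probability `p`, and treats 0.9 as "high confidence"; both are assumed cutoffs. Crop extraction itself is left out. The function name is hypothetical.

```python
def confidence_buckets(cell_probs, cell_labels, threshold=0.5, high_conf=0.9):
    """Sort cells into (high confidence?, correct?) groups for crop generation.

    cell_probs: dict {cell_id: mean predicted probability}.
    cell_labels: dict {cell_id: 1, 0 or -1}; cells labelled -1 are skipped.
    Returns dict mapping (is_high_conf, is_correct) -> list of cell ids.
    """
    buckets = {(h, c): [] for h in (True, False) for c in (True, False)}
    for cell_id, p in cell_probs.items():
        y = cell_labels.get(cell_id)
        if y not in (0, 1):
            continue  # unlabelled or unspecific cell
        confident = max(p, 1 - p) >= high_conf
        correct = (p >= threshold) == bool(y)
        buckets[(confident, correct)].append(cell_id)
    return buckets
```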
Model post-processing
[x] Constant thresholding across channels. MVP implementation for this task. To go from marker probabilities to cell-level calls, the output needs to be post-processed. The simplest scheme is to average the per-pixel softmax outputs within each cell and use a threshold of 0.5 to call each marker positive or negative. A cell counts as accurate only if its binary positive/negative call is correct for every marker.
[ ] Mapping of cells to closest marker combination. Rather than a simple thresholding, it may make sense to do some sort of distance-based metric to determine which out of a set of pre-determined expression patterns the cell most closely resembles
[ ] Integration with FlowSOM. Instead of defining a marker matrix with the possible cell types, we feed the output of the model as the input to cell-based FlowSOM to generate clusters
[x] Write out predictions for sample selection (noisy/not noisy) as an additional value in the tfrecord and as an additional column in the cell table.
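The "closest marker combination" item could start as a nearest-neighbour lookup against a marker matrix. This sketch assumes per-cell marker probabilities in an `(n_cells, n_markers)` array and a hypothetical dict of expected 0/1 patterns per cell type. Euclidean distance is an assumed choice; other distances are equally plausible.

```python
import numpy as np


def nearest_phenotype(cell_scores, patterns):
    """Assign each cell to the closest predefined expression pattern.

    cell_scores: (n_cells, n_markers) array of per-cell marker probabilities.
    patterns: dict {cell_type: sequence of expected 0/1 values per marker}.
    """
    names = list(patterns)
    P = np.asarray([patterns[n] for n in names], dtype=float)  # (n_types, n_markers)
    # Pairwise Euclidean distances between cells and patterns: (n_cells, n_types).
    d = np.linalg.norm(cell_scores[:, None, :] - P[None, :, :], axis=-1)
    return [names[i] for i in np.argmin(d, axis=1)]
```

Unlike per-marker thresholding, this forces every cell onto a plausible phenotype. A cell that sits far from all patterns could additionally be flagged by its minimum distance.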
Cropping model:
[ ] Classification data loader. This would frame the task as the classification of a single cell, rather than pixelwise classification of the whole image
[ ] Basic image augmentation for classification model. This would augment the input image channel separately from the input masks.
[ ] Basic image generator for classification model. This would load the TFRecord, pass the data to the augmenter, assemble batches of the appropriate size, and then pass them to the model.
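For the classification data loader, one way to frame "classify a single cell" is a fixed-size crop centred on the cell, with that cell's binary mask added as an extra input channel. The sketch below follows that approach. The centring on the centroid, the zero-padding at borders, and the function name are all assumptions.

```python
import numpy as np


def crop_cell(image, instance_mask, cell_id, size=32):
    """Fixed-size crop centred on one cell, with its binary mask as an extra channel.

    image: (H, W, C) array; instance_mask: (H, W) int array.
    The image is zero-padded so crops near the border keep the requested size.
    """
    ys, xs = np.nonzero(instance_mask == cell_id)
    if ys.size == 0:
        raise ValueError(f"cell {cell_id} not found in mask")
    cy, cx = int(round(ys.mean())), int(round(xs.mean()))  # cell centroid
    half = size // 2
    pad = ((half, half), (half, half))
    img = np.pad(image, pad + ((0, 0),))
    cell = np.pad((instance_mask == cell_id).astype(image.dtype), pad)
    # After padding, original pixel (cy, cx) sits at (cy + half, cx + half),
    # so a window starting at (cy, cx) is centred on the cell.
    crop_img = img[cy:cy + size, cx:cx + size]
    crop_mask = cell[cy:cy + size, cx:cx + size]
    return np.concatenate([crop_img, crop_mask[..., None]], axis=-1)
```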
The following is a summary of potential improvements/features for the classification model
Experiment Plan
Data normalization:
Dataset creation:
Image augmentation
Image generator
Model architecture
Training data selection
Model training
Model evaluation
Model error visualization
Model post-processing
Cropping model: