angelolab / Nimbus


Multi dataset #48

Closed JLrumberger closed 1 year ago

JLrumberger commented 1 year ago

What is the purpose of this PR?

This PR closes #39 by adding multi-dataset training for PromixNaive and ModelBuilder. It also makes it possible to define the constituents of the input data explicitly (e.g. [marker_channel, binary_mask] as input, or [marker_channel, binary_mask, nuclei_img, membrane_img]), so that we can experiment with what works best here.

How did you implement your changes

Multi-dataset training: I changed many small parts of ModelBuilder and PromixNaive regarding dataset preparation. Before, all of this code worked on single datasets, whereas now it works on lists of datasets. The training datasets are ultimately collapsed via tf.data.Dataset.sample_from_datasets into a single train_dataset that samples from all of its constituent datasets. In addition, I made PromixNaive.class_wise_loss_selection calculate percentile thresholds for every marker and every dataset individually. Finally, I added multi-dataset capabilities to evaluation_script.py.
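For reference, a minimal sketch of the merging step (placeholder datasets, weights and seed; not the exact Nimbus code):

```python
import tensorflow as tf

# Hypothetical per-dataset pipelines standing in for e.g. the TONIC and decidua tfrecords.
tonic_ds = tf.data.Dataset.from_tensor_slices(tf.random.uniform([100, 4]))
decidua_ds = tf.data.Dataset.from_tensor_slices(tf.random.uniform([50, 4]))

# Collapse the constituent datasets into a single train_dataset that samples
# from them with the given probabilities.
train_dataset = tf.data.Dataset.sample_from_datasets(
    [tonic_ds, decidua_ds], weights=[0.9, 0.1], seed=42
)
train_dataset = train_dataset.shuffle(256).batch(8)
```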

Training with additional input channels: The static class functions ModelBuilder.prep_batches and PromixNaive.prep_batches_promix are now generated by the class functions ModelBuilder.gen_prep_batches_fn and PromixNaive.gen_prep_batches_promix_fn. The latter take a list of constituent channel names (e.g. "mplex_img", "binary_mask", "nuclei_img", "membrane_img") and return a function that takes an example dict and returns batches whose input data consists of the requested constituent channels.
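A rough illustration of that closure pattern (the dict keys, target key and defaults here are assumptions for illustration, not the exact implementation):

```python
import tensorflow as tf

def gen_prep_batches_fn(constituent_keys=("mplex_img", "binary_mask")):
    """Return a prep_batches function that stacks the requested constituent
    channels from an example dict into the model input tensor.

    Simplified sketch: the real gen_prep_batches_fn may use different keys,
    ordering and target handling.
    """
    def prep_batches(example):
        # Concatenate the selected channels along the channel axis.
        inputs = tf.concat([example[key] for key in constituent_keys], axis=-1)
        targets = example["activity"]  # assumed target key, for illustration only
        return inputs, targets
    return prep_batches

# Usage: map the generated function over a dataset of example dicts, e.g.
# prep_fn = gen_prep_batches_fn(["mplex_img", "binary_mask", "nuclei_img", "membrane_img"])
# train_dataset = train_dataset.map(prep_fn)
```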

Fixed some bugs in metrics.py: The function calc_metrics threw errors when some cells had label==2 in the ground-truth activity. I added code that excludes these cells from the metric calculation.
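The exclusion could look roughly like this (a sketch assuming NumPy arrays of per-cell labels; the actual calc_metrics fix may differ):

```python
import numpy as np

def filter_ambiguous_cells(y_true, y_pred, exclude_label=2):
    """Drop cells whose ground-truth activity equals exclude_label before
    computing metrics (sketch of the exclusion described above)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    keep = y_true != exclude_label
    return y_true[keep], y_pred[keep]
```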

Remaining issues: There could still be some issues when using this code in production, since I touched many things for this PR.

ngreenwald commented 1 year ago

In terms of checking that this didn't inadvertently break anything, here are some options:

  1. Use this branch to train a model on the TONIC dataset only, and check that the final confidence percentiles are equivalent to training with the single model code
  2. Use this branch to train a model with the TONIC dataset plus 1 additional image from a different dataset, and check that the percentiles on the TONIC channels are equivalent.
  3. Duplicate the TONIC dataset into TONIC_1 and TONIC_2, and check that training with two datasets gives thresholds that are similar (but not exactly equal) to each other. Check that these probabilities are also similar to what we got before
  4. For any of the above models, check the accuracy on the test dataset

Let me know what you think, or if you have other ideas for good sanity checks.

JLrumberger commented 1 year ago

There seems to be a considerable performance gap between baseline models trained with the old and the new codebase, and I'm trying to find the root cause. It seems to be a problem with the newly generated tfrecord dataset that includes the membrane and nuclei channels. I ran two identical models with the new codebase on the new and old datasets, and they produced pretty different results.

JLrumberger commented 1 year ago

Checked the datasets and they are identical except for the newly introduced nuclei and membrane channels.

ngreenwald commented 1 year ago

So could it be a problem with how those channels are constructed? You’re still seeing a difference in performance for the same model on the old and new dataset?

JLrumberger commented 1 year ago

All these models use the same training hyperparameters and these are my observations:

I am a bit suspicious now that it could be due to multi-GPU training. All models with the new codebase are trained on 4 GPUs.
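One way to rule this out is to pin a run to a single device, e.g. (a hedged sketch, not the project's actual training entry point):

```python
import os

# Make only one GPU visible before TensorFlow initializes, so the run is
# directly comparable to the old single-GPU baselines.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf

# Alternatively, scope model/optimizer creation to a single-device strategy.
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])  # placeholder model
```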

ngreenwald commented 1 year ago

I agree, we should try and change as few things as possible given that this PR is already so big.

Were the models trained on the new dataset also using the nuclei/membrane channel? Or it was generated but not used during training?

JLrumberger commented 1 year ago

> I agree, we should try and change as few things as possible given that this PR is already so big.
>
> Were the models trained on the new dataset also using the nuclei/membrane channel? Or it was generated but not used during training?

The models were trained without the nuclei/membrane channel.

ngreenwald commented 1 year ago

Okay, so to summarize:

What's your hypothesis for why the new codebase with the new dataset looks different from the new codebase with the old dataset, given that the datasets are identical? Could this just be random noise from one bad training run?

JLrumberger commented 1 year ago

Yep, I hope it's because of the initialization. I restarted all setups on a single GPU, so tomorrow I'll know more.

ngreenwald commented 1 year ago

Sounds good!

JLrumberger commented 1 year ago

Oh god, it was the multi-GPU training. The loss for the baseline models looks normal when trained on a single GPU. The PromixNaive models look as good as before. I just have one small change to commit and then you can review and we can merge it in.

ngreenwald commented 1 year ago

In addition to the training loss, will you run the test metrics and make sure there aren't any differences in performance?

Which of the above scenarios did you test out? Keeping the TONIC dataset the same, but training with the multi-dataset code?

JLrumberger commented 1 year ago

I tested the following scenarios:

- train baseline on tonic only with old dataset, f1=0.694
- train baseline on tonic only with new dataset, f1=0.688
- train baseline on tonic and decidua, sampling=[0.9999, 0.0001], f1=0.668
- train baseline on tonic and tonic, sampling=[0.5, 0.5], f1=0.7214

- train promix on tonic only with new dataset, f1=0.656
- train promix on tonic and decidua, sampling=[0.9999, 0.0001], f1=0.633
- train promix on tonic, decidua, msk_colon, sampling=[0.34, 0.33, 0.33], f1=0.634

So far I have only looked at the validation loss and compared it to older models. I can calculate validation metrics and look for differences there, but it should look alright given that the validation loss is similar to before.

ngreenwald commented 1 year ago

Okay great! Sounds like you rooted out the problem. I think it'll just be better to be super sure that there aren't any lingering issues before we move forward.

JLrumberger commented 1 year ago

@ngreenwald I am happy with this PR now.