mehta-lab / microDL

3D virtual staining with 2D and 2.5D U-Nets
BSD 3-Clause "New" or "Revised" License

Inference for gunpowder dataloading #190

Closed: Christianfoley closed this issue 1 year ago

Christianfoley commented 1 year ago

As of the latest commits to the gunpowder dataloading -> pytorch implementation PR, preprocessing and training using the gunpowder backend and an HCS-compatible zarr store for data IO are in a working (but untested) state.

Config file formats have been documented here, and an example config file that refers to a real Vero dataset on ESS can be found at /hpc/projects/CompMicro/projects/virtualstaining/torch_microDL/config_files/2022_11_01_VeroMemNuclStain/gunpowder_testing_12_13/torch_config_25D.yml. Note that preprocessing modifies any zarr store it is called on in place. If you want to preserve a copy of the raw data, copy the store before running preprocessing on it.
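Because preprocessing writes into the store in place, a backup step can run first. A minimal sketch, assuming the store sits on disk as a plain directory (the default zarr v2 directory layout), so a recursive copy keeps every chunk and every .zattrs/.zarray file. The function name backup_zarr_store is hypothetical, not part of microDL:

```python
import shutil
import tempfile
from pathlib import Path


def backup_zarr_store(store_path, backup_path):
    """Copy a directory-backed zarr store before in-place preprocessing.

    Assumes the store is a plain on-disk directory, so copying the tree
    preserves all array chunks and metadata files (.zattrs, .zarray, ...).
    """
    store_path = Path(store_path)
    backup_path = Path(backup_path)
    if backup_path.exists():
        # Refuse to overwrite an existing copy of the raw data.
        raise FileExistsError(f"backup already exists: {backup_path}")
    shutil.copytree(store_path, backup_path)
    return backup_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Build a tiny fake store: one metadata file and one array directory.
        raw = Path(tmp) / "raw.zarr"
        (raw / "0").mkdir(parents=True)
        (raw / ".zattrs").write_text("{}")
        copy = backup_zarr_store(raw, Path(tmp) / "raw_backup.zarr")
        print(sorted(p.name for p in copy.iterdir()))
```

For large stores on ESS this copy can take a while, so it is worth doing once per raw dataset rather than before every preprocessing run.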

To run inference, a script (or microDL module) needs to be developed that does the following:

  1. Builds a gunpowder pipeline that includes no augmentations. This can be done the same way as in the data visualization script here, by using the TorchTrainer.generate_dataloaders() method, or manually, by calling multi_zarr_source to build your sources and then adding RandomProvider, RandomLocation, Reject, PrepMaskRoi, FlatFieldCorrect, and Normalize nodes to the pipeline, in that order, as is done here. All of these nodes have fairly comprehensive docstrings.
  2. Ensures that the data this pipeline pulls from is validation or test data that was not used for training. The positions used for "train", "test", and "val" in the most recent training run are currently saved in the top-level .zattrs, so they can be queried at inference time. Note that these positions are overwritten by every new training run. In the future, they should also be saved as a .csv in the model directory of that training run.
  3. Randomly samples positions from this pipeline (there are examples of doing this in a standalone pipeline, from a dataloader that feeds the model, and from a dataloader that does not feed the model), runs them through the model, computes metrics on the outputs, and records visualizations of them (an example here).
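Steps 2 and 3 can be sketched without gunpowder: read the recorded split, keep only held-out positions, sample from them, and score predictions. This is a minimal sketch under stated assumptions: the top-level .zattrs is read as plain JSON, the split lives under a hypothetical key "data_split", and Pearson correlation stands in as one example metric. Keep in mind that the recorded split only reflects the most recent training run.

```python
import json
import random
import tempfile
from pathlib import Path

import numpy as np


def load_heldout_positions(store_path, split_key="data_split", subsets=("val", "test")):
    """Return the positions recorded as val/test, refusing any that overlap train.

    Assumes the top-level .zattrs is JSON of the form
    {"data_split": {"train": [...], "val": [...], "test": [...]}}.
    The "data_split" key is hypothetical; use whatever key training writes.
    """
    attrs = json.loads((Path(store_path) / ".zattrs").read_text())
    split = attrs[split_key]
    train = set(split["train"])
    heldout = [pos for subset in subsets for pos in split.get(subset, [])]
    # Guard against data leakage between training and evaluation.
    leaked = train.intersection(heldout)
    if leaked:
        raise ValueError(f"held-out positions also used for training: {sorted(leaked)}")
    return heldout


def pearson_r(prediction, target):
    """Pearson correlation coefficient between a prediction and its target."""
    p = np.asarray(prediction, dtype=float).ravel()
    t = np.asarray(target, dtype=float).ravel()
    return float(np.corrcoef(p, t)[0, 1])


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        split = {"train": [0, 1, 2], "val": [3], "test": [4, 5]}
        (Path(tmp) / ".zattrs").write_text(json.dumps({"data_split": split}))
        heldout = load_heldout_positions(tmp)
        sampled = random.Random(0).sample(heldout, k=2)  # seeded random sample
        print("held-out:", heldout, "sampled:", sampled)

    # Toy stand-ins for a model prediction and its ground-truth target.
    target = np.arange(16, dtype=float).reshape(4, 4)
    prediction = target + np.random.default_rng(0).normal(scale=0.5, size=target.shape)
    print("Pearson r:", round(pearson_r(prediction, target), 3))
```

In a real script, the sampled positions would feed the augmentation-free pipeline from step 1, and the metric would be computed on the model output for each sample.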

If your pipeline contains FlatFieldCorrect and Normalize nodes, you don't need to normalize its data before running inference on it. The flatfield correction and normalization values are stored in the .zattrs and untracked arrays of each position, and these nodes access them on the fly. You might, however, want to un-z-score the prediction to restore its dynamic range for visualization. This is tricky, because you need to know which position each sample came from, which might require a new custom node.
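Once a sample carries the key of its source position, un-z-scoring is a single affine map. A minimal sketch, assuming normalization was z = (x - mean) / std, so the inverse is x = z * std + mean. The POSITION_STATS table, its keys, and the position names are hypothetical stand-ins for values that would be read from each position's .zattrs:

```python
import numpy as np

# Hypothetical per-position statistics; real key names and storage may differ.
POSITION_STATS = {
    "pos_0": {"mean": 120.0, "std": 15.0},
    "pos_1": {"mean": 98.5, "std": 22.0},
}


def unzscore(prediction, position, stats=POSITION_STATS):
    """Map a z-scored prediction back to the intensity range of its position.

    Assumes z = (x - mean) / std was applied, so x = z * std + mean.
    Requires knowing which position the sample came from.
    """
    s = stats[position]
    return np.asarray(prediction, dtype=float) * s["std"] + s["mean"]


if __name__ == "__main__":
    z = np.array([[-1.0, 0.0], [1.0, 2.0]])
    print(unzscore(z, "pos_0"))
```

If the Normalize node uses a different formula (for example, an epsilon in the denominator or median/IQR instead of mean/std), the inverse has to match it exactly.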

Soorya19Pradeep commented 1 year ago

Hello @Christianfoley. Thank you for working on the inference code! I see you changed the dataloader function at line 151 in micro_dl/torch_unet/utils/training.py to load the recorded data split. I am unable to run training, because the code errors when the recorded data split it looks for does not exist. My understanding is that the train/val/test split is computed and stored during training, then loaded back and used during inference. It seems the code needs a check so that the recorded split is only loaded during inference and is computed during training.

Christianfoley commented 1 year ago

@Soorya19Pradeep I pushed changes without updating the documentation; apologies for not keeping you in the loop. I will update the documentation and the example config files.

Implementing inference added some config parameters: "use_recorded_data_split" is now a mandatory parameter. Following today's discussion, this will change; see issue #203.
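The check suggested above (compute the split during training, load the recorded one during inference) can be expressed as a single branch on use_recorded_data_split. This is a minimal sketch, not the microDL implementation: the function name get_data_split and the train_fraction/val_fraction config keys are hypothetical, and the recorded split is passed in rather than read from .zattrs.

```python
import random


def get_data_split(positions, config, recorded_split=None, seed=0):
    """Return a {"train", "val", "test"} split of position keys.

    During training (use_recorded_data_split: False) a fresh split is computed.
    During inference (use_recorded_data_split: True) the split recorded by the
    last training run is reused, and a clear error is raised if it is missing.
    """
    if config["use_recorded_data_split"]:
        if recorded_split is None:
            raise KeyError(
                "use_recorded_data_split is True but no recorded split was found; "
                "run training first or set it to False"
            )
        return recorded_split

    # Hypothetical config keys; default to a 70/15/15 split.
    train_frac = config.get("train_fraction", 0.7)
    val_frac = config.get("val_fraction", 0.15)

    shuffled = list(positions)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible
    n_train = int(len(shuffled) * train_frac)
    n_val = int(len(shuffled) * val_frac)
    return {
        "train": shuffled[:n_train],
        "val": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],
    }


if __name__ == "__main__":
    split = get_data_split(range(10), {"use_recorded_data_split": False})
    print({name: len(members) for name, members in split.items()})
```

Training would then record the returned split (in .zattrs and, eventually, a .csv in the model directory) so that a later inference run can pass it back in.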

Christianfoley commented 1 year ago

Completed with recent commit

Soorya19Pradeep commented 1 year ago

@Christianfoley, I tried out the inference script and ran into an error.

Traceback (most recent call last):
  File "/home/soorya.pradeep/microDL/micro_dl/cli/torch_inference_script.py", line 103, in <module>
    main(args.config, args.gpu, args.gpu_mem_frac)
  File "/home/soorya.pradeep/microDL/micro_dl/cli/torch_inference_script.py", line 96, in main
    torch_predictor.generate_dataloaders()
  File "/home/soorya.pradeep/microDL/micro_dl/torch_unet/utils/inference.py", line 115, in generate_dataloaders
    torch_data_container = ds.InferenceDatasetContainer(
  File "/home/soorya.pradeep/microDL/micro_dl/torch_unet/utils/dataset.py", line 255, in __init__
    ) = gp_utils.multi_zarr_source(
  File "/home/soorya.pradeep/microDL/micro_dl/utils/gunpowder_utils.py", line 319, in multi_zarr_source
    source_position = get_zarr_source_position(source)
  File "/home/soorya.pradeep/microDL/micro_dl/utils/gunpowder_utils.py", line 125, in get_zarr_source_position
    zarr_source = zarr_source[0]
IndexError: list index out of range

I think it stems from loading the data split from the metadata. Is that correct? If so, am I missing something in the config or in the training step? Thank you!

Christianfoley commented 1 year ago

Hi Soorya, it's a bit hard to tell, but it looks like the problem is that multi_zarr_source is not giving the InferenceDataset enough sources. Because inference requires that each slice in each position be predicted exactly once, we can't just randomly sample; instead, the InferenceDataset needs to build many identical pipelines, each with a copy of the ZarrSource referring to one position. This error tells me that it did not receive many copies, only one.
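The "each slice exactly once" requirement amounts to deterministic enumeration instead of random sampling. A minimal sketch, assuming a hypothetical slices_per_position mapping stands in for the per-position sources that multi_zarr_source would build; it also shows an explicit guard that turns an empty source list into a readable error rather than an IndexError deeper in the stack. It ignores the neighboring-slice context a 2.5D model needs around each slice.

```python
def inference_items(slices_per_position):
    """Yield every (position, z) pair exactly once, in a fixed order.

    slices_per_position maps a position key to its number of z-slices.
    Unlike training, nothing is sampled at random, so each slice of each
    position is visited exactly once.
    """
    if not slices_per_position:
        # Fail with a readable message instead of an IndexError further down.
        raise ValueError("no positions/sources were provided for inference; check the config")
    for position in sorted(slices_per_position):
        for z in range(slices_per_position[position]):
            yield position, z


if __name__ == "__main__":
    items = list(inference_items({"pos_1": 2, "pos_0": 3}))
    print(items)
```

Sorting the position keys keeps the output order reproducible, which makes it easier to write predictions back to the matching position and slice.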

It is possible that this is a config issue. Can you share the config that generated this error?

Here is the config I used to run integration tests; it has some new parameters for inference: /hpc/projects/comp_micro/projects/virtualstaining/torch_microDL/config_files/2022_11_01_VeroMemNuclStain/gunpowder_testing_12_13/

Let's move this discussion underneath the PR.