mehta-lab / microDL

3D virtual staining with 2D and 2.5D U-Nets
BSD 3-Clause "New" or "Revised" License

Testing of pytorch_implementation branch #175

Closed Christianfoley closed 1 year ago

Christianfoley commented 1 year ago

I've made this into an issue so we can have a coherent thread in which we can reference parts of the code and I can answer questions about the workflow.

I have just done some preliminary testing:

There are more thorough instructions and documentation of the new torch config file in the branch's micro_dl/torch_unet/readme.md. The general 'gist' of the PyTorch workflow is that, for now, we still need valid preprocessing, training, and inference configs for some set of data, but most of the parameters pertaining to training and model initialization (see the example training configs) can be ignored; they are instead part of the torch_config.yml file. This is clunky, but it is only a temporary solution for testing, and once we've determined the integrity of the PyTorch models we can phase out the other config files.
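For orientation, here is a minimal sketch of what a torch_config.yml might contain, assembled only from parameters mentioned later in this thread; their grouping and the rest of the layout are assumptions, and micro_dl/torch_unet/readme.md remains the authoritative reference:

```yaml
# Hedged sketch -- only the parameter names are taken from this thread;
# their grouping under "training" and all values are assumptions.
training:
  mask: False            # segmentation-only masking (see discussion below)
  caching: False         # naive sample caching (see discussion below)
  num_workers: 4         # dataloader multiprocessing (added later in this thread)
  testing_stride: 5      # how often to run the test pass
  save_model_stride: 5   # how often to save model weights
```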

Christianfoley commented 1 year ago

Updates after deep work: I've updated the handling of split_samples.json metadata; it is now saved in the same directory as the model weights, as is done in the tf version. Inference should also now automatically load the validation samples when running training and inference in sequence on Bruno.

One issue I was aware of is that the naive caching I implemented saves samples as they are provided by the tensorflow dataset, meaning that it will not work with random augmentations, as those are recomputed 'on the fly' in the backend of the dataset. I've updated the caching system so that caching is automatically disabled when testing augmentations; it can be force re-enabled by specifying caching: True in the training section of the torch_config.yml config file if augmentations are not being used and faster training is desired.
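Concretely, re-enabling the cache looks something like this (its placement under the training section follows the description above; everything else in the snippet is illustrative):

```yaml
training:
  # Only safe when no random augmentations are configured: cached samples
  # would otherwise freeze augmentations that should be recomputed per epoch.
  caching: True
```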

Christianfoley commented 1 year ago

Soorya discovered some issues with the way that augmentations had been implemented. I updated the code to fix these issues in this commit. @JohannaRahm if you plan to test with augmentations I think you will need to pull these changes.

Christianfoley commented 1 year ago

Potential memory issue: GPU memory usage spikes (from around 4 GB to ~30 GB) mid-epoch.

Christianfoley commented 1 year ago

For 2.5D U-Nets, inference currently re-runs the same prediction on the entire z-stack for every slice in the stack, instead of running it once on the stack. This should be changed to run once on the entire z-stack for each position provided for 2.5D U-Nets, and once on each slice and position provided for 2D U-Nets.

JohannaRahm commented 1 year ago

I've started a test with a model that translates phase to membrane stain (1 channel in, 1 channel out) and used a working set of 3 config files (preprocess, train, inference) in combination with the torch config. In preprocessing, the images were masked with the Otsu method. The torch_unet readme.md says that the parameter mask should mostly be set to true, but setting this parameter to true gives an error. I collected the error message and config files in the following locations:

What does mask False mean? Will tiles with mostly background not be filtered out during training?

Christianfoley commented 1 year ago

I collected the error message and config files in the following locations:

Hi Johanna. Could you show the error message that you are getting? I didn't see a copy of it in the directory. I would try to reproduce it with your config files, but I do not have permission to access the files or the data/metadata they reference (which is caused by a known issue).

What does mask False mean? Will tiles with mostly background not be filtered out during training?

The mask parameter in the torch config file should really be hidden (and renamed as well); thanks for bringing this up. It applies only to masking out parts of the samples during training when the network's task is segmentation. We don't plan to highlight this functionality, so it probably shouldn't be specified in the default config file. Tiles with mostly background are filtered out during training if the 'masks' section is present in the preprocessing config. This is performed by the tensorflow-based dataset module, and should function in exactly the same way as in tensorflow-based microDL.
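For completeness, this background filtering lives in the preprocessing config rather than the torch config. A hedged sketch of the relevant section follows; the presence of a masks section is what matters, while the individual keys shown are typical of the tf-based configs and may differ from yours:

```yaml
# Hedged sketch: a masks section enables mask generation, and tiles are
# then filtered by foreground content; key names may differ in your config.
masks:
  channels: [1]       # channel(s) used to generate the masks
  mask_type: otsu     # thresholding method, e.g. otsu or unimodal
tile:
  min_fraction: 0.25  # discard tiles with less foreground than this fraction
```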

JohannaRahm commented 1 year ago

Let's chat about how to fix the permission errors.

Thanks for the explanation about the mask parameter. Does this mean that the parameter should be kept as False in case of a regression task?

The error messages are located here, but I am also inserting them directly in this post, as this is clearer.

mask = true, otsu

(/hpc/user_apps/comp_micro/conda_envs/microdl_torch) [johanna.rahm@gpu-b-2 microDL]$ python /home/johanna.rahm/microDL/micro_dl/cli/torch_train_script.py  --config /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/mem/torch_config_2D_mem.yml
Using TensorFlow backend.
/home/johanna.rahm/microDL/micro_dl/input/dataset.py:135: UserWarning: Warning: tf-dependent BaseDataSet to be replaced with GunPowder in 2.1.0
  warnings.warn('Warning: tf-dependent BaseDataSet to be replaced with GunPowder in 2.1.0')
/home/johanna.rahm/microDL/micro_dl/input/inference_dataset.py:14: UserWarning: InferenceDataSet class to be replaced with gunpowder in 2.1.0
  warnings.warn('InferenceDataSet class to be replaced with gunpowder in 2.1.0')

Epoch 0:
     training 1/13791 [>_________________________________________________]
Traceback (most recent call last):
  File "/home/johanna.rahm/microDL/micro_dl/cli/torch_train_script.py", line 42, in <module>
    trainer.train()
  File "/home/johanna.rahm/microDL/micro_dl/torch_unet/utils/training.py", line 187, in train
    output = self.model(input_, validate_input = True)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/johanna.rahm/microDL/micro_dl/torch_unet/networks/Unet2D.py", line 173, in forward
    x = self.down_conv_blocks[i](x, validate_input = validate_input)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/johanna.rahm/microDL/micro_dl/torch_unet/networks/layers/ConvBlock2D.py", line 232, in forward
    x = self.conv_list[i](x)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 446, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 443, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected 4-dimensional input for 4-dimensional weight[16, 1, 3, 3], but got 5-dimensional input of size [64, 64, 1, 256, 256] instead
(/hpc/user_apps/comp_micro/conda_envs/microdl_torch) [johanna.rahm@gpu-b-2 microDL]$ 

mask = true, unimodal

(/hpc/user_apps/comp_micro/conda_envs/microdl_torch) [johanna.rahm@gpu-b-2 microDL]$ python /home/johanna.rahm/microDL/micro_dl/cli/torch_train_script.py  --config /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/mem/torch_config_2D_mem.yml
Using TensorFlow backend.
/home/johanna.rahm/microDL/micro_dl/input/dataset.py:135: UserWarning: Warning: tf-dependent BaseDataSet to be replaced with GunPowder in 2.1.0
  warnings.warn('Warning: tf-dependent BaseDataSet to be replaced with GunPowder in 2.1.0')
/home/johanna.rahm/microDL/micro_dl/input/inference_dataset.py:14: UserWarning: InferenceDataSet class to be replaced with gunpowder in 2.1.0
  warnings.warn('InferenceDataSet class to be replaced with gunpowder in 2.1.0')

Epoch 0:
Traceback (most recent call last):
  File "/home/johanna.rahm/microDL/micro_dl/cli/torch_train_script.py", line 42, in <module>
    trainer.train()
  File "/home/johanna.rahm/microDL/micro_dl/torch_unet/utils/training.py", line 171, in train
    for current, minibatch in enumerate(self.train_dataloader): 
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/johanna.rahm/microDL/micro_dl/torch_unet/utils/dataset.py", line 139, in __getitem__
    sample_target = transform(sample_target)
  File "/home/johanna.rahm/microDL/micro_dl/torch_unet/utils/dataset.py", line 316, in __call__
    masks.append(mask_utils.create_unimodal_mask(sample[i,0,0]))
  File "/home/johanna.rahm/microDL/micro_dl/utils/masks.py", line 110, in create_unimodal_mask
    thr = get_unimodal_threshold(input_image)
  File "/home/johanna.rahm/microDL/micro_dl/utils/masks.py", line 92, in get_unimodal_threshold
    assert best_threshold > -np.inf, 'Error in unimodal thresholding'
AssertionError: Error in unimodal thresholding
(/hpc/user_apps/comp_micro/conda_envs/microdl_torch) [johanna.rahm@gpu-b-2 microDL]$ 

Soorya19Pradeep commented 1 year ago

I get the same error as Johanna once I have trained a 2D model to predict nucleus from phase and run the inference.

It seems like an error related to the dimension of input tiles created during preprocessing.

time 0 position 5:   0%|                                                                                                                             | 0/4 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "micro_dl/cli/torch_inference_script.py", line 93, in <module>
    image_predictor.run_prediction()
  File "/hpc/mydata/soorya.pradeep/microDL/micro_dl/inference/image_inference.py", line 726, in run_prediction
    chan_slice_meta,
  File "/hpc/mydata/soorya.pradeep/microDL/micro_dl/inference/image_inference.py", line 638, in predict_2d
    input_image=cur_input
  File "/hpc/mydata/soorya.pradeep/microDL/micro_dl/torch_unet/utils/inference.py", line 63, in predict_large_image
    pred = torch.unsqueeze(model(img_tensor), -3)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/hpc/mydata/soorya.pradeep/microDL/micro_dl/torch_unet/networks/Unet2D.py", line 173, in forward
    x = self.down_conv_blocks[i](x, validate_input = validate_input)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/hpc/mydata/soorya.pradeep/microDL/micro_dl/torch_unet/networks/layers/ConvBlock2D.py", line 251, in forward
    x = self.conv_list[i](x)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 446, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 443, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected 4-dimensional input for 4-dimensional weight[16, 1, 3, 3], but got 3-dimensional input of size [1, 2048, 2048] instead

JohannaRahm commented 1 year ago

At the end of model training an error occurs.

Settings:

1 channel in, 1 channel out
mask = False
augmentation = True

For more info check out the yml files at /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/mem/errors/mask_false/

Command: python /home/johanna.rahm/microDL/micro_dl/cli/torch_train_script.py --config /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/mem/torch_config_2D_mem.yml

Error message:

     training 13791/13791 [==================================================>] (100%)
     testing 1/32 [=>________________________________________________] (3%)
Traceback (most recent call last):
  File "/home/johanna.rahm/microDL/micro_dl/cli/torch_train_script.py", line 42, in <module>
    trainer.train()
  File "/home/johanna.rahm/microDL/micro_dl/torch_unet/utils/training.py", line 201, in train
    test_loss = self.run_test(i)
  File "/home/johanna.rahm/microDL/micro_dl/torch_unet/utils/training.py", line 287, in run_test
    cmap = 'gray')
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/matplotlib/__init__.py", line 1447, in inner
    return func(ax, *map(sanitize_sequence, args), **kwargs)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/matplotlib/axes/_axes.py", line 5523, in imshow
    im.set_data(X)
  File "/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/matplotlib/image.py", line 712, in set_data
    .format(self._A.shape))
TypeError: Invalid shape (1, 256, 256) for image data

JohannaRahm commented 1 year ago

Preprocessing for multichannel input generates a warning. No warning is observed for single channel input.

Path to preprocess config file: /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/mem/errors/mutichannel_preprocess/config_preprocess_2022_03_15_nuc_mem_z25-60.yml

Command: python /home/johanna.rahm/microDL/micro_dl/cli/preprocess_script.py --config /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/nuc_mem/config_preprocess_2022_03_15_nuc_mem_z25-60.yml

Warning:

tile image t000 p335 z060 c000...
tile image t000 p335 z060 c001...
tile image t000 p335 z060 c002...
/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/numpy/lib/arraysetops.py:580: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  mask |= (ar1 == a)

Christianfoley commented 1 year ago

Preprocessing for multichannel input generates a warning. No warning is observed for single channel input.

Hmm, I've never run into this warning before. According to the warning message, it seems like this issue was handled automatically. I'm not as familiar with the preprocessing code, though. @jennyfolkesson, have you seen this before?

jennyfolkesson commented 1 year ago

Thanks for bringing this issue to attention @JohannaRahm. I haven't seen that warning before. It looks like numpy is complaining about comparing different data types, e.g. scalar or array vs string. Preprocessing was built to handle multiple channels as input; since it makes no distinction between inputs and targets, it always processes a minimum of two channels. So I'm surprised this warning hasn't shown up before... Has the preprocessing part of the config file changed at all? If it performs as expected despite the warning, we can keep it as is but pay attention when we're upgrading numpy versions.

Christianfoley commented 1 year ago

At the end of model training an error occurs

Settings:

1 channel in, 1 channel out
mask = False
augmentation = True

For more info check out the yml files at /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/mem/errors/mask_false/

@JohannaRahm I know exactly where this error comes from, but I wasn't able to reproduce it when training a 2D network on my side with similar configs. Perhaps the culprit is the data format from preprocessing? When you get a chance, could you update the permissions on the config files (and their data & metadata) that generated this error, so that I can debug further?

For now I have added some error handling in this commit that makes this non-fatal, so you can continue to train; you may just be missing a figure or two in the progress summaries after each epoch.

Christianfoley commented 1 year ago

I get the same error as Johanna once I have trained a 2D model to predict nucleus from phase and run the inference.

It seems like an error related to the dimension of input tiles created during preprocessing.

Hi @Soorya19Pradeep. After a little looking, it seems like this issue is caused by not specifying a depth for the input channel in the preprocessing. If you only specify a depth for some of the channels (for example depth: [1,1] when num_channels = 3), then the third channel, which in your case happens to be the phase input, will not be allocated a dimension. That means the data loaded in will be 4-dimensional instead of 5. The code will remove one of these 4 dimensions (what it thinks is the single 2D z slice), leaving just 3, while the convolution still expects 4 dimensions (hence your error message).

In your case, since you have three channels, if you specify depth = 1 for each of them (depth: [1,1,1]), the data loaded in for inference will have 5 dimensions (batch, channels, z (1), x, y), and the 2D U-Net will prune the z-dimension properly.
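In config form, that fix might look like the snippet below (the enclosing section name is an assumption; the key spelling follows the comment above and may differ slightly in the actual preprocessing config):

```yaml
tile:
  # one entry per channel, including the phase input channel,
  # so every channel is allocated a z dimension
  depth: [1, 1, 1]
```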

To make this a bit easier (so you don't need to rerun preprocessing), I've made it so that inference is able to use contextual information to deal with the missing dimension. The changes should be in the most recent commit.

JohannaRahm commented 1 year ago

Numpy warning

/hpc/user_apps/comp_micro/conda_envs/microdl_torch/lib/python3.6/site-packages/numpy/lib/arraysetops.py:580: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  mask |= (ar1 == a)

I have now also observed this warning for 1 channel in and 1 channel out. Agreed that this does not break any functionality and that we should keep it in the back of our minds when updating the numpy version.

The permissions are changed now, and you should be able to use the files I created. Let me know if you still run into trouble.

Thanks for fixing the training error @Christianfoley! I can run training without errors with commit 65049c01a8ea39ad3c2444b5e494e546e593e065.

JohannaRahm commented 1 year ago

Image_dirs key error in inference

It was possible to define multiple directories as image_dirs (list of str) in the inference config. I have never used this feature and only defined one directory in the image_dirs list. Just wanted to mention this in case we are interested in keeping this feature.

image_dirs: [/hpc/projects/comp_micro/projects/HEK/2022_03_15_orgs_nuc_mem_63x_04NA/all_pos_single_page/all_pos_Phase1e-3_Denconv_Nuc8e-4_Mem8e-4_pad15_bg50_registered_refmem_min25_max60]

This now gives a KeyError; image_dir (str) is required instead.

image_dir: /hpc/projects/comp_micro/projects/HEK/2022_03_15_orgs_nuc_mem_63x_04NA/all_pos_single_page/all_pos_Phase1e-3_Denconv_Nuc8e-4_Mem8e-4_pad15_bg50_registered_refmem_min25_max60

Command: python /home/johanna.rahm/microDL/micro_dl/cli/torch_inference_script.py --config /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/mem_small/torch_config_2D_mem_small_image_dirs.yml

Error message:

(/hpc/user_apps/comp_micro/conda_envs/microdl_torch) [johanna.rahm@gpu-b-2 microDL]$ python /home/johanna.rahm/microDL/micro_dl/cli/torch_inference_script.py --config /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/mem_small/torch_config_2D_mem_small.yml
/home/johanna.rahm/microDL/micro_dl/input/dataset.py:135: UserWarning: Warning: tf-dependent BaseDataSet to be replaced with GunPowder in 2.1.0
  warnings.warn('Warning: tf-dependent BaseDataSet to be replaced with GunPowder in 2.1.0')
Using TensorFlow backend.
/home/johanna.rahm/microDL/micro_dl/input/inference_dataset.py:14: UserWarning: InferenceDataSet class to be replaced with gunpowder in 2.1.0
  warnings.warn('InferenceDataSet class to be replaced with gunpowder in 2.1.0')
Using GPU 0 with memory fraction 0.8169310977219274.
PyTorch model load status: <All keys matched successfully>
Traceback (most recent call last):
  File "/home/johanna.rahm/microDL/micro_dl/cli/torch_inference_script.py", line 92, in <module>
    preprocess_config = preprocess_config)
  File "/home/johanna.rahm/microDL/micro_dl/inference/image_inference.py", line 106, in __init__
    self.image_dir = inference_config["image_dir"]
KeyError: 'image_dir'

Config file: /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/mem_small/config_inference_2022_03_15_mem_small_image_dirs.yml

Commit: 65049c01a8ea39ad3c2444b5e494e546e593e065

JohannaRahm commented 1 year ago

I trained three different models (5 epochs, subset of data):

  1. phase in, membrane out, 2D
  2. phase in, membrane out, 2.5D
  3. phase in, nucleus and membrane out, 2D

No errors occurred during preprocessing, training and inference! There are no black and white artifacts in the inferred images.

Config file locations:

  1. /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/mem_small/
  2. /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/25D_small/
  3. /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/nuc_mem_small/

Inference result locations:

  1. /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/mem_small/training_model_2022_10_20_02_04/val/
  2. /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/25D_small/training_model_2022_10_20_02_57/val/
  3. /hpc/projects/comp_micro/projects/virtualstaining/2022_microDL_nuc_mem/configfiles/tests/pytorch_branch/nuc_mem_small/training_model_2022_10_20_02_50/val/

Commit: 65049c01a8ea39ad3c2444b5e494e546e593e065

Christianfoley commented 1 year ago

It was possible to define multiple directories as image_dirs (list of str) in the inference config. I have never used this feature and only defined one directory in the image_dirs list. Just wanted to mention this in case we are interested to keep this feature.

I actually was not aware we supported this feature. I will look into how much additional work it would be to continue supporting it; thanks for bringing it up!

No errors occurred during preprocessing, training and inference! There are no black and white artifacts in the inferred images.

Thank you for running these tests! The lack of speckle artifacts is good! In that case, perhaps the model that exhibits the artifacts is learning something specific to Soorya's A549 cell data.

Christianfoley commented 1 year ago

@Soorya19Pradeep @JohannaRahm As of the most recent commits (1, 2, 3), the pytorch dataloading pipeline now supports multiprocessing! This should dramatically speed up training, as we can load samples while the network is processing previous ones.

To enable this feature, add a num_workers: 4 parameter (probably around 4, no more than 8; see readme) to the torch_config.yml file that you reference in training and inference.

I have noticed a speedup of between 3 and 4x depending on dataset size, but there is substantial overhead for initializing each worker every epoch, so I'm sure the increase will be even greater for Johanna's large datasets. I suggest that you also change testing_stride to be the same as your save_model_stride... maybe something like your epochs / 5.
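Putting those suggestions together, the training section of the torch config might gain something like the following (values are examples only, not recommendations beyond what is stated above):

```yaml
training:
  num_workers: 4        # dataloader worker processes; no more than 8 (see readme)
  testing_stride: 5     # e.g. epochs / 5, to amortize per-epoch worker startup
  save_model_stride: 5  # keep equal to testing_stride as suggested above
```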

mattersoflight commented 1 year ago

This issue is now resolved with the switch to pytorch-lightning.