kreshuklab / plant-seg

A tool for cell instance aware segmentation in densely packed 3D volumetric images
https://kreshuklab.github.io/plant-seg/
MIT License
88 stars, 31 forks

`patch_halo` seems to be ignored #205

Closed qin-yu closed 5 months ago

qin-yu commented 6 months ago

I am fixing this right now.

My version: v1.6.6

My workflow: CLI

My config:

cnn_prediction:
  # enable/disable UNet prediction
  state: True
  # key for H5 or ZARR, can be set to null if only one key exists in each file; null is recommended if the previous step has state True
  key: Null
  # channel to use if input image has shape CZYX or CYX, otherwise set to null; null is recommended if the previous step has state True
  channel: Null
  # Trained model name, more info on available models and custom models in the README
  model_name: 'PlantSeg_3Dnuc_platinum'
  # If a CUDA-capable GPU is available and correctly set up, use "cuda"; if not, you can use "cpu" for CPU-only inference (slower)
  device: 'cuda'
  # (int or tuple) padding to be removed from each axis in a given patch in order to avoid checkerboard artifacts
  patch_halo: [32, 32, 32]
  # how many subprocesses to use for data loading
  num_workers: 8
  # patch size given to the network (adapt to fit in your GPU mem)
  patch: [45, 666, 516]
  # stride between patches will be computed as `stride_ratio * patch`
  # recommended values are in range `[0.5, 0.75]` to make sure the patches have enough overlap to get smooth prediction maps
  stride_ratio: 0.3
  # If "True" forces downloading networks from the online repos
  model_update: False
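
To make the `stride_ratio * patch` comment in the config concrete, here is a minimal sketch of the arithmetic. The helper name `compute_stride` and the choice to truncate to an integer (with a floor of 1) are assumptions for illustration, not necessarily what PlantSeg does internally.

```python
from typing import Sequence, Tuple


def compute_stride(patch: Sequence[int], stride_ratio: float) -> Tuple[int, ...]:
    """Hypothetical helper: per-axis stride as stride_ratio * patch.

    Assumption: the product is truncated to an integer and never drops below 1.
    """
    return tuple(max(1, int(p * stride_ratio)) for p in patch)


# Values taken from the config above.
stride = compute_stride([45, 666, 516], 0.3)
print(stride)
```

With `stride_ratio: 0.3`, consecutive patches overlap by roughly 70% along each axis, which is more overlap than the `[0.5, 0.75]` range recommended in the config comment.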

Workflow output:

algorithm: UnetPredictions
file_suffix: _predictions
h5_output_key: predictions
input_channel: null
input_key: null
input_paths:
- /g/kreshuk/yu/Datasets/EMBL/CTischer2024Signalling/nuclei/PreProcessing/cell_t000_nuclei.h5
...
- /g/kreshuk/yu/Datasets/EMBL/CTischer2024Signalling/nuclei/PreProcessing/cell_t009_nuclei.h5
input_type: data_float32
model_name: PlantSeg_3Dnuc_platinum
out_ext: .h5
output_type: data_float32
patch:
- 45
- 300
- 300
predictor: !!python/object:plantseg.predictions.functional.array_predictor.ArrayPredictor
  batch_size: 2
  device: cuda
  disable_tqdm: false
  is_embedding: false
  model: !!python/object:plantseg.training.model.UNet3D
    ...
  out_channels: 2
  patch_halo:
  - 2
  - 4
  - 4
  verbose_logging: false
save_directory: /g/kreshuk/yu/Datasets/EMBL/CTischer2024Signalling/nuclei/PreProcessing/PlantSeg_3Dnuc_platinum
save_raw: false
state: true
stride_ratio: 0.3

qin-yu commented 6 months ago

In plantseg/predictions/functional/predictions.py's unet_predictions():

    patch_halo = get_patch_halo(model_name)
    predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'],
                               out_channels=model_config['out_channels'], device=device, patch=patch,
                               patch_halo=patch_halo, single_batch_mode=single_batch_mode, headless=False,
                               verbose_logging=False, disable_tqdm=disable_tqdm)

In plantseg/predictions/predict.py's UnetPredictions:

        patch_halo = get_patch_halo(model_name)
        is_embedding = not model_config.get('is_segmentation', True)
        self.predictor = ArrayPredictor(model=model, in_channels=model_config['in_channels'],
                                        out_channels=model_config['out_channels'], device=device, patch=self.patch,
                                        patch_halo=patch_halo, single_batch_mode=False, headless=True,
                                        is_embedding=is_embedding)

But:

def get_patch_halo(model_name):
    predict_template = get_predict_template()
    patch_halo = predict_template['predictor']['patch_halo']

    config_train = get_train_config(model_name)
    if config_train["model"]["name"] == "UNet2D":
        patch_halo[0] = 0

    return patch_halo

So I guess the user config is not used. Should be an easy fix.
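
One possible shape for the fix, as a sketch only: prefer the halo from the user config and fall back to a default when none is given. The function name `resolve_patch_halo`, the constant `DEFAULT_PATCH_HALO`, and the `model_is_2d` flag are hypothetical; they are not part of the PlantSeg API. The default `[2, 4, 4]` is the value that shows up in the workflow output above.

```python
from typing import List, Optional, Sequence, Union

# Hypothetical default, matching the halo reported in the workflow output.
DEFAULT_PATCH_HALO = [2, 4, 4]


def resolve_patch_halo(
    user_halo: Optional[Union[int, Sequence[int]]],
    model_is_2d: bool,
) -> List[int]:
    """Return the halo to use: the user's value if given, else the default."""
    if user_halo is None:
        halo = list(DEFAULT_PATCH_HALO)  # copy, so the default is never mutated
    elif isinstance(user_halo, int):
        halo = [user_halo] * 3  # the config comment allows "int or tuple"
    else:
        halo = list(user_halo)
    if len(halo) != 3:
        raise ValueError(f"patch_halo must have 3 entries, got {halo}")
    if model_is_2d:
        halo[0] = 0  # same rule as get_patch_halo(): no halo along z for 2D models
    return halo


print(resolve_patch_halo([32, 32, 32], model_is_2d=False))
print(resolve_patch_halo(None, model_is_2d=True))
```

Copying the default before modifying it also avoids changing a shared list in place when the 2D rule sets the first entry to 0.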

qin-yu commented 6 months ago

Hey @wolny, I guess a halo of [2, 4, 4] is fine for all purposes? Is that why we currently have a fixed halo in PlantSeg? Or do we actually want to let users set it?

qin-yu commented 6 months ago

According to this paper, the halo matters: Exact Tile-Based Segmentation Inference for Images Larger than GPU Memory. Empirically, though, changing the halo doesn't seem to help.
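
A toy 1D illustration of why the halo should matter in principle. This is only a sketch: the "network" here is a moving average standing in for a model with a fixed receptive field, and `smooth` and `tiled_predict` are made-up names, not PlantSeg or pytorch-3dunet code. Each tile is predicted together with `halo` extra samples on each side, and the halo is cropped off before stitching.

```python
import numpy as np


def smooth(x: np.ndarray, radius: int) -> np.ndarray:
    """Stand-in for a network: moving average that sees `radius` samples per side."""
    kernel = np.ones(2 * radius + 1) / (2 * radius + 1)
    return np.convolve(x, kernel, mode="same")  # zero-padded at the edges


def tiled_predict(x: np.ndarray, tile: int, halo: int, radius: int) -> np.ndarray:
    """Predict tile by tile, reading `halo` extra samples per side and cropping them."""
    out = np.empty_like(x)
    n = len(x)
    for start in range(0, n, tile):
        stop = min(start + tile, n)
        lo = max(start - halo, 0)  # extend the tile by the halo, clipped to the signal
        hi = min(stop + halo, n)
        pred = smooth(x[lo:hi], radius)
        # Keep only the part of the prediction that belongs to [start, stop).
        out[start:stop] = pred[start - lo : start - lo + (stop - start)]
    return out


rng = np.random.default_rng(0)
signal = rng.random(64)
radius = 3
full = smooth(signal, radius)

exact = np.allclose(tiled_predict(signal, tile=16, halo=radius, radius=radius), full)
broken = np.allclose(tiled_predict(signal, tile=16, halo=0, radius=radius), full)
print(exact, broken)
```

In this toy setup, a halo at least as large as the receptive-field radius reproduces the untiled result, while a halo of 0 produces errors at every tile border.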

wolny commented 5 months ago

hey @qin-yu, yes, at some point I decided to simplify the process and use a fixed 'patch_halo' during inference instead of letting the user specify it. If there are still references to 'patch_halo' in the config, I'd just remove them since they're not used.

qin-yu commented 4 months ago

With the latest implementation of the halo, its shape has become a significant factor and can no longer be simplified to a fixed value. In hindsight, the fact that the halo's size made no difference was a key indicator of underlying issues.

I've implemented some changes in https://github.com/wolny/pytorch-3dunet/pull/113 for pytorch-3dunet. After receiving a review from @wolny, I plan to submit a pull request for further updates in plantseg.

qin-yu commented 4 months ago