fepegar / torchio

Medical imaging toolkit for deep learning
https://torchio.org
Apache License 2.0

Parallel histogram_standardization.train #970

Open hsyang1222 opened 2 years ago

hsyang1222 commented 2 years ago

🚀 Feature A parallel version of torchio.transforms.preprocessing.intensity.histogram_standardization.train

Motivation Currently, this function runs in a single process. Deep learning models that use torchio are typically trained on servers with many CPU cores, so computing the histogram landmarks in parallel would increase efficiency.

Code I wrote the code below using multiprocessing.Pool. I think it could be useful.

import torchio
import tqdm
import numpy as np
from pathlib import Path
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from torchio.typing import TypePath

from torchio.transforms.preprocessing.intensity import histogram_standardization

import multiprocessing

DEFAULT_CUTOFF = 0.01, 0.99
STANDARD_RANGE = 0, 100
TypeLandmarks = Union[TypePath, Dict[str, Union[TypePath, np.ndarray]]]

def train(
        images_paths: Sequence[TypePath],
        cutoff: Optional[Tuple[float, float]] = None,
        mask_path: Optional[Union[Sequence[TypePath], TypePath]] = None,
        masking_function: Optional[Callable] = None,
        output_path: Optional[TypePath] = None,
        num_workers: int = 32
) -> np.ndarray:
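    # Intended as a drop-in parallel replacement for
    # histogram_standardization.train: same arguments, plus num_workers
    # to control the size of the process pool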
    is_masks_list = isinstance(mask_path, Sequence)
    if is_masks_list and len(mask_path) != len(images_paths):  # type: ignore[arg-type]  # noqa: E501
        message = (
            f'Different number of images ({len(images_paths)})'  # type: ignore[arg-type]  # noqa: E501
            f' and mask ({len(mask_path)}) paths found'  # type: ignore[arg-type]  # noqa: E501
        )
        raise ValueError(message)
    quantiles_cutoff = DEFAULT_CUTOFF if cutoff is None else cutoff
    percentiles_cutoff = 100 * np.array(quantiles_cutoff)
    percentiles_database = []
    a, b = percentiles_cutoff  # for mypy
    percentiles = histogram_standardization._get_percentiles((a, b))

    mask_path_list = [None] * len(images_paths)
    masking_function_list = [None] * len(images_paths)

    if masking_function is not None:
        masking_function_list = [masking_function] * len(images_paths)
    else:
        if is_masks_list:
            mask_path_list = mask_path
        else:
            mask_path_list = [mask_path] * len(images_paths)

    # At this point at least one of masking_function and mask_path is None

    percentiles_list = [percentiles] * len(images_paths)

    # Each image is processed independently, so a process pool can be used;
    # imap_unordered yields percentile values as the workers finish
    with multiprocessing.Pool(num_workers) as pool:
        with tqdm.tqdm(total=len(images_paths), desc="make histogram") as pbar:
            args_zipped = zip(images_paths, masking_function_list, mask_path_list, percentiles_list)
            for percentile_values in pool.imap_unordered(img_to_percentiles_value, args_zipped):
                percentiles_database.append(percentile_values)
                pbar.update()

    percentiles_database_array = np.vstack(percentiles_database)
    mapping = histogram_standardization._get_average_mapping(percentiles_database_array)

    if output_path is not None:
        output_path = Path(output_path).expanduser()
        extension = output_path.suffix
        if extension == '.txt':
            modality = 'image'
            text = f'{modality} {" ".join(map(str, mapping))}'
            output_path.write_text(text)
        elif extension == '.npy':
            np.save(output_path, mapping)
    return mapping

def img_to_percentiles_value(args):
    # Worker function: read one image and compute its percentile values,
    # optionally restricted to a mask
    image_file_path, masking_function, mask_path, percentiles = args
    tensor, _ = histogram_standardization.read_image(image_file_path)

    if masking_function is not None:
        mask = masking_function(tensor)
    else:
        if mask_path is None:
            mask = np.ones_like(tensor, dtype=bool)
        else:
            path = mask_path  # type: ignore[assignment]
            mask, _ = histogram_standardization.read_image(path)
            mask = mask.numpy() > 0
    array = tensor.numpy()
    percentile_values = np.percentile(array[mask], percentiles)
    return percentile_values
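
For reference, a minimal usage sketch; the image paths, the output file name, and the 't1' image key below are placeholders:

landmarks = train(
    images_paths=['subject_001_t1.nii.gz', 'subject_002_t1.nii.gz'],  # placeholder paths
    output_path='t1_landmarks.npy',  # placeholder output file
    num_workers=8,
)
transform = torchio.transforms.HistogramStandardization({'t1': landmarks})

Because the workers are created with multiprocessing, the call should be wrapped in an if __name__ == '__main__': guard when the script runs on platforms that use the spawn start method.
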
fepegar commented 2 years ago

Hi, @hsyang1222. If you have tried this successfully, feel free to open a pull request with your changes.