Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Fully-connected 3D CRF #315

Closed LucasFidon closed 8 months ago

LucasFidon commented 4 years ago

Is your feature request related to a problem? Please describe. 3D CRFs have been widely used as a post-processing step in segmentation applications.

Describe the solution you'd like We have made an implementation of a fully-connected 3D CRF in PyTorch publicly available: https://github.com/SamuelJoutard/Permutohedral_attention_module

This could be added directly to MONAI.

The 3D CRF itself is based on the existing implementation in NiftyNet, and it relies on an efficient implementation of Gaussian filtering in CUDA. Our implementation is several times faster than the NiftyNet implementation.
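To make the idea concrete, here is a minimal brute-force sketch of mean-field inference for a fully-connected CRF on a 1D signal. It is plain Python for illustration only, not the PyTorch/CUDA implementation linked above: the function names are hypothetical, the parameter names are borrowed from the prototype later in this thread, and the pairwise sum is evaluated exactly in O(n²), which is the cost that fast Gaussian filtering is meant to avoid. A Potts compatibility (neighbours reward the same label) is assumed.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def mean_field_crf(logits, intensities, iterations=5,
                   bilateral_weight=3.0, gaussian_weight=1.0,
                   bilateral_spatial_sigma=5.0, bilateral_color_sigma=0.5,
                   gaussian_spatial_sigma=5.0):
    """Brute-force mean-field inference on a 1D signal (illustrative only).

    logits: one list of class scores per voxel (the unary term).
    intensities: one image value per voxel (used by the bilateral kernel).
    Returns one list of class probabilities per voxel.
    """
    n = len(logits)
    num_classes = len(logits[0])

    # Dense pairwise kernel: a spatial Gaussian plus a bilateral term that
    # also decays with the intensity difference. No self-connections.
    kernel = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            dist2 = float((i - j) ** 2)
            diff2 = (intensities[i] - intensities[j]) ** 2
            gaussian = math.exp(-dist2 / (2 * gaussian_spatial_sigma ** 2))
            bilateral = math.exp(-dist2 / (2 * bilateral_spatial_sigma ** 2)
                                 - diff2 / (2 * bilateral_color_sigma ** 2))
            kernel[i][j] = gaussian_weight * gaussian + bilateral_weight * bilateral

    # Start from the softmax of the unary term, then update all voxels in parallel.
    q = [softmax(u) for u in logits]
    for _ in range(iterations):
        updated = []
        for i in range(n):
            # Message passing: kernel-weighted sum of the other voxels' beliefs.
            message = [sum(kernel[i][j] * q[j][c] for j in range(n))
                       for c in range(num_classes)]
            # Potts compatibility: support for the same label raises its score.
            updated.append(softmax([logits[i][c] + message[c]
                                    for c in range(num_classes)]))
        q = updated
    return q
```

On a toy signal, a voxel whose logits disagree with nearby voxels of similar intensity tends to be pulled toward their label, while a sharp intensity edge limits smoothing across it.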

wyli commented 4 years ago

Looks really nice. Perhaps we could start with the pure PyTorch version? We haven't figured out how to package/distribute CUDA kernels in MONAI. (Have you tried to contribute the kernels to PyTorch directly?)

LucasFidon commented 4 years ago

@wyli Yes, starting with the pure PyTorch version in the PAM folder sounds good. It would only require changing one of the imports in the crf.py file.

As motivation for future integration of the CUDA code: on my machine, the CRF with the CUDA version is 5x faster than the pure PyTorch version at inference for a 136x136x136 volume.

Regarding the integration of the CUDA kernels, I will let @SamuelJoutard reply, as he is the expert on this.

SamuelJoutard commented 4 years ago

@wyli No, I did not try to submit the kernel to PyTorch, because the HashTable's memory is not dynamically allocated (the kernel operates on a pre-allocated, "relatively large" amount of memory), and this could be improved in my opinion.

In addition to the speed-up @LucasFidon mentioned, the CUDA version properly handles collisions in the HashTable, which could marginally improve the results.
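For readers unfamiliar with the trade-off, the following is a small sketch of a hash table that works on a block of slots allocated once up front and resolves collisions by linear probing. It is plain Python and purely illustrative: the class and method names are hypothetical, and it does not reproduce the actual CUDA kernel, whose hashing and probing scheme is not described in this thread.

```python
class FixedHashTable:
    """Open-addressing hash table over a pre-allocated block of slots."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.keys = [None] * capacity    # allocated once, never resized
        self.values = [0.0] * capacity
        self.size = 0

    def _find_slot(self, key):
        """Return the slot holding `key`, or the first free slot on its
        probe sequence; return None if the table is full and lacks `key`."""
        start = hash(key) % self.capacity
        for offset in range(self.capacity):
            slot = (start + offset) % self.capacity
            if self.keys[slot] is None or self.keys[slot] == key:
                return slot
        return None

    def accumulate(self, key, value):
        """Add `value` to the entry for `key`, creating the entry if needed."""
        slot = self._find_slot(key)
        if slot is None:
            # Nothing can grow here: the memory budget was fixed up front.
            raise MemoryError("pre-allocated hash table is full")
        if self.keys[slot] is None:
            self.keys[slot] = key
            self.size += 1
        self.values[slot] += value

    def get(self, key, default=0.0):
        """Return the accumulated value for `key`, or `default` if absent."""
        slot = self._find_slot(key)
        if slot is None or self.keys[slot] is None:
            return default
        return self.values[slot]
```

Because colliding keys are moved to the next free slot rather than merged, their accumulated values stay separate. The cost of the fixed allocation is that the capacity has to be chosen generously in advance, which matches the "relatively large" pre-allocation mentioned above.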

tvercaut commented 4 years ago

This issue hinges on #785

wyli commented 3 years ago

As #785 is now fixed, do you want to revisit this ticket, @LucasFidon @SamuelJoutard? The C++/CUDA code is organised in https://github.com/Project-MONAI/MONAI/tree/master/monai/csrc

currently (as of monai v0.3) in a git-cloned MONAI codebase, running:

tvercaut commented 3 years ago

@charliebudd will be able to help. Charlie: have a look at https://github.com/Project-MONAI/MONAI/blob/master/CONTRIBUTING.md

wyli commented 3 years ago

Bilateral filtering has been merged (#1375). As discussed with @charliebudd, optional todos are:

tvercaut commented 3 years ago

@charliebudd Following our discussion earlier, it would be good to double check the status of the CRF code in MONAI. If it's finalised already, you can close this issue.

In any case, it would be good to have a demo notebook for it. Not sure if it fits in MONAI proper or in https://github.com/Project-MONAI/tutorials

For the record, we had a similar notebook in NiftyNet already that could be ported (unless you have better ideas of course): https://github.com/NifTK/NiftyNet/blob/dev/demos/crf_as_rnn/crf_as_rnn_inference_demo.ipynb

masadcv commented 3 years ago

Hey @charliebudd, I was wondering if you are looking into the tutorial mentioned by @tvercaut ? If it is still open for contribution then perhaps I can help convert the NiftyNet tutorial to use MONAI's CRF implementation.

charliebudd commented 3 years ago

I just set up a draft PR with the notebook in... https://github.com/Project-MONAI/tutorials/pull/170. The CRF is not yet in the released version of MONAI that the tutorials use.

wyli commented 3 years ago

feature request: CRF Post Processing as a MONAI Transform (from @masadcv #2196)

Is your feature request related to a problem? Please describe. In a number of deep-learning-based segmentation models, conditional random fields (CRFs) are used as a post-processing step that refines the model output into segmentation maps more consistent with the underlying regions of the image.

Describe the solution you'd like At the moment, MONAI provides CRF layers that can enable this. It may be beneficial to have dictionary/array Transforms that wrap the CRF, so that it can be composed into a post-processing pipeline like any other transform.

Describe alternatives you've considered Using the CRF layer from MONAI as a separate model layer that is attached to a model or applied to its outputs. It may be more convenient to separate this out into a Transform so that it can be quickly used in post-processing steps.

Additional context As an example, the following is an initial prototype to give an idea of how this may be approached:

import torch
from monai.networks.blocks import CRF
from monai.transforms import Transform

class ApplyCRFPostProcd(Transform):
    def __init__(
        self,
        unary: str,
        pairwise: str,
        post_proc_label: str = 'postproc',
        iterations: int = 5, 
        bilateral_weight: float = 3.0,
        gaussian_weight: float = 1.0,
        bilateral_spatial_sigma: float = 5.0,
        bilateral_color_sigma: float = 0.5,
        gaussian_spatial_sigma: float = 5.0,
        compatibility_kernel_range: float = 1,
        device = torch.device('cpu'),
    ):
        self.unary = unary
        self.pairwise = pairwise
        self.post_proc_label = post_proc_label
        self.device = device

        self.crf_layer = CRF(
                iterations, 
                bilateral_weight,
                gaussian_weight,
                bilateral_spatial_sigma,
                bilateral_color_sigma,
                gaussian_spatial_sigma,
                compatibility_kernel_range
                )

    def __call__(self, data):
        d = dict(data)
        unary_term = d[self.unary].float().to(self.device)
        pairwise_term = d[self.pairwise].float().to(self.device)
        d[self.post_proc_label] = self.crf_layer(unary_term, pairwise_term)
        return d

Example usage of above as post processing would be:

from monai.transforms import SqueezeDimd, ToNumpyd

post_transforms = [
    ApplyCRFPostProcd(unary='logits', pairwise='image', post_proc_label='pred'),
    SqueezeDimd(keys='pred', dim=0),
    ToNumpyd(keys='pred'),
]
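The dictionary-transform contract that the prototype relies on (copy the input dictionary, read the named keys, write the result under a new key) can be sketched without MONAI or PyTorch installed. Everything below is a hypothetical stand-in for illustration: StubCRF simply returns the unary term, and apply_all is a bare-bones substitute for a compose utility.

```python
class StubCRF:
    """Hypothetical stand-in for a CRF layer: returns the unary term unchanged."""

    def __call__(self, unary, pairwise):
        return list(unary)


class ApplyPostProcd:
    """Dictionary transform: reads two keys and writes the result to a third."""

    def __init__(self, unary, pairwise, post_proc_label="postproc", layer=None):
        self.unary = unary
        self.pairwise = pairwise
        self.post_proc_label = post_proc_label
        self.layer = layer if layer is not None else StubCRF()

    def __call__(self, data):
        d = dict(data)  # shallow copy, so the caller's dictionary is not modified
        d[self.post_proc_label] = self.layer(d[self.unary], d[self.pairwise])
        return d


def apply_all(transforms, data):
    """Apply each transform in order, feeding each output to the next."""
    for transform in transforms:
        data = transform(data)
    return data


sample = {"logits": [0.1, 0.9], "image": [0.0, 1.0]}
result = apply_all(
    [ApplyPostProcd(unary="logits", pairwise="image", post_proc_label="pred")],
    sample,
)
```

Writing to a separate output key keeps the raw logits available to later transforms, which is one reason a dictionary transform may be more convenient here than attaching the CRF as a model layer.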

Please let me know your thoughts on this, and whether it makes sense to have it as a Transform. If so, I am happy to work on it.

vikashg commented 8 months ago

close merged