Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Random cropping with scaling option for super-resolution data augmentation #4624

Status: Open. masadcv opened this issue 2 years ago.

masadcv commented 2 years ago

Is your feature request related to a problem? Please describe. In super-resolution networks, a low-resolution input is upsampled by a neural network, often by an integer factor (e.g. x2, x3, x4). For data augmentation, a good approach is random cropping with a fixed size, especially when the input images are too large for the network activations to fit in memory.

Describe the solution you'd like: To address this, one possible solution is to pick a random crop window in the low-resolution image and apply the same window, scaled up, to the target ground-truth image.

For example, in an x4 upsampling network, a 56 x 56 crop window in the low-res input corresponds to a 224 x 224 crop window in the target ground truth.
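The mapping amounts to multiplying both the window origin and the window size by the scale factor. Below is a minimal sketch, assuming 2D images indexed as (H, W) and an integer factor. `scaled_crop_window` is a hypothetical helper name used only for illustration, not a MONAI function.

```python
import random


def scaled_crop_window(lowres_shape, roi_size, scale, rng=random):
    """Hypothetical helper (not a MONAI API): pick a random low-res window
    and return the matching hi-res window.

    Returns ((y, x, h, w) in low-res pixels, (y, x, h, w) in hi-res pixels).
    """
    h, w = roi_size
    # randint is inclusive at both ends, so the window always fits inside the image
    y = rng.randint(0, lowres_shape[0] - h)
    x = rng.randint(0, lowres_shape[1] - w)
    lowres_window = (y, x, h, w)
    # scale the origin AND the size so both windows cover the same field of view
    hires_window = (y * scale, x * scale, h * scale, w * scale)
    return lowres_window, hires_window


lo, hi = scaled_crop_window((128, 128), (56, 56), 4, rng=random.Random(0))
print(lo, hi)
```

Whatever origin is drawn, the hi-res window is always 224 x 224 and starts at four times the low-res coordinates.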

Describe alternatives you've considered: Writing my own MONAI transforms, or cropping the data manually.

Additional context: This may be useful for super-resolution, upsampling or demosaicing networks, which typically take a low-resolution input and increase its resolution.

Apologies if an existing transform already addresses this. In that case, could someone guide me on how to achieve the above with it?

wyli commented 2 years ago

Thanks for the request. If I understand this correctly, it is similar to https://github.com/Project-MONAI/MONAI/issues/4491, which mainly requires defining crop locations and sizes in physical units instead of numbers of pixels (essentially, a 56x56 pixel window at 4mm x 4mm resolution is equivalent to a 224x224 window at 1mm x 1mm). Please help clarify, @masadcv, and we can prioritise it.
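In physical units, the pixel count of a window is its extent divided by the pixel spacing, so a single physical ROI gives a different pixel size at each resolution. Below is a rough sketch of that conversion, assuming the spacing is given per axis in mm. `roi_mm_to_pixels` is a hypothetical helper, not part of MONAI.

```python
def roi_mm_to_pixels(roi_mm, spacing_mm):
    """Hypothetical helper (not a MONAI API): convert an ROI size in mm to pixels per axis."""
    pixels = []
    for extent, spacing in zip(roi_mm, spacing_mm):
        n = extent / spacing
        # only accept sizes that land on whole pixels, to avoid silent rounding
        if abs(n - round(n)) > 1e-6:
            raise ValueError(f"{extent} mm is not a whole number of {spacing} mm pixels")
        pixels.append(int(round(n)))
    return tuple(pixels)


roi_mm = (224.0, 224.0)  # 56 px at 4 mm == 224 mm == 224 px at 1 mm
print(roi_mm_to_pixels(roi_mm, (4.0, 4.0)))  # (56, 56)
print(roi_mm_to_pixels(roi_mm, (1.0, 1.0)))  # (224, 224)
```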

masadcv commented 2 years ago

Hi @wyli, yes, my case is a special case of that issue: I am only interested in cropping data of different sizes.

My not-so-polished solution currently looks like this:

import random

from monai.transforms import Transform


class RandCropScaled(Transform):
    """Randomly crop a low-res image and crop the matching, scaled-up window from hi-res images."""

    def __init__(
        self,
        lowres,
        hires_list,
        roi_size=(56, 56),
        hires_scale=4,
        allow_missing_keys=False,
    ) -> None:
        super().__init__()
        self.lowres = lowres              # key of the low-resolution input
        self.hires_list = hires_list      # keys of the high-resolution targets
        self.roi_size = list(roi_size)    # crop size (H, W) in low-res pixels
        self.hires_scale = hires_scale    # integer upsampling factor
        self.allow_missing_keys = allow_missing_keys

    def _fetch_data(self, data, key):
        if key not in data:
            raise KeyError(f"Key {key} not found, present keys {list(data.keys())}")
        return data[key]

    def _calculate_random_crop_loc(self, data_np):
        # random.randint is inclusive at both ends, so the window always fits
        y = random.randint(0, data_np.shape[-2] - self.roi_size[0])
        x = random.randint(0, data_np.shape[-1] - self.roi_size[1])
        return y, x

    def _crop_array(self, data_np, yx, roi_size):
        # crop the last two (spatial) dimensions, keeping any leading channel dims
        return data_np[..., yx[0]: yx[0] + roi_size[0], yx[1]: yx[1] + roi_size[1]]

    def __call__(self, data):
        d = dict(data)

        lowres = self._fetch_data(d, self.lowres)
        yloc_lowres, xloc_lowres = self._calculate_random_crop_loc(lowres)
        # the same window in hi-res coordinates: scale both the origin and the size
        yloc_hires, xloc_hires = yloc_lowres * self.hires_scale, xloc_lowres * self.hires_scale
        hires_roi_size = [s * self.hires_scale for s in self.roi_size]

        d[self.lowres] = self._crop_array(lowres, [yloc_lowres, xloc_lowres], self.roi_size)

        for nk in self.hires_list:
            if nk not in d:
                if self.allow_missing_keys:
                    continue
                raise KeyError(f"Key {nk} not found, present keys {list(d.keys())}")
            d[nk] = self._crop_array(d[nk], [yloc_hires, xloc_hires], hires_roi_size)

        return d

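
One way to sanity-check the scaled coordinates is to synthesise the low-res image from the hi-res one by block averaging. Any correctly aligned pair of crops must then agree once the hi-res crop is downsampled the same way. Below is a sketch of that check in NumPy, assuming a single-channel (C, H, W) image and an integer factor. `block_mean` is a hypothetical helper, and block averaging is only a stand-in for whatever degradation a real dataset uses.

```python
import random

import numpy as np


def block_mean(img, s):
    """Hypothetical helper: downsample the last two axes of a (C, H, W) array by averaging s x s blocks."""
    c, h, w = img.shape
    return img.reshape(c, h // s, s, w // s, s).mean(axis=(2, 4))


scale, roi = 4, (56, 56)
rng = np.random.default_rng(0)
hires = rng.random((1, 512, 512))
lowres = block_mean(hires, scale)  # shape (1, 128, 128)

# random low-res origin, then the same window scaled up for the hi-res image
y = random.randint(0, lowres.shape[1] - roi[0])
x = random.randint(0, lowres.shape[2] - roi[1])
lo_crop = lowres[..., y:y + roi[0], x:x + roi[1]]
hi_crop = hires[..., y * scale:(y + roi[0]) * scale, x * scale:(x + roi[1]) * scale]

print(lo_crop.shape, hi_crop.shape)  # (1, 56, 56) (1, 224, 224)
# downsampling the hi-res crop reproduces the low-res crop only if the windows align
print(np.allclose(block_mean(hi_crop, scale), lo_crop))  # True
```
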
wyli commented 2 years ago

Thank you @masadcv, the prototype is helpful.

masadcv commented 2 years ago

Hi @wyli, just wanted to check whether anyone is working on this. If not, I would like to contribute this transform.

wyli commented 2 years ago

Not at the moment, @masadcv. But now that the MetaTensor implementation is on dev, I hope it can provide a universal way of specifying roi_sizes in physical units instead of numbers of voxels (for example, hires_scale should be computed from MetaTensor's pixdim property). Would you still be interested in contributing, and do you have the bandwidth?
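As a rough illustration of deriving hires_scale from voxel spacing instead of hard-coding it, the factor per axis is the low-res spacing divided by the hi-res spacing. The sketch below takes plain spacing tuples, for example values that could be read from MetaTensor's pixdim property as suggested above. `scale_from_spacing` is a hypothetical helper, and requiring an integer factor is an assumption that matches the prototype.

```python
def scale_from_spacing(lowres_spacing, hires_spacing, tol=1e-6):
    """Hypothetical helper (not a MONAI API): per-axis integer upsampling factor from two spacings."""
    factors = []
    for lo, hi in zip(lowres_spacing, hires_spacing):
        ratio = lo / hi
        # assumption: only positive integer factors are supported, as in the prototype
        if abs(ratio - round(ratio)) > tol or round(ratio) < 1:
            raise ValueError(f"spacing ratio {ratio} is not a positive integer")
        factors.append(int(round(ratio)))
    return tuple(factors)


print(scale_from_spacing((4.0, 4.0), (1.0, 1.0)))  # (4, 4)
print(scale_from_spacing((2.0, 3.0), (1.0, 1.0)))  # (2, 3)
```

Returning one factor per axis rather than a single number also covers anisotropic data, where the prototype's scalar hires_scale would not be enough.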