Project-MONAI / tutorials

MONAI Tutorials
https://monai.io/started.html
Apache License 2.0

How to use image volumes with no boxes in the training dataset for RetinaNet 3D? #1292

Open AceMcAwesome77 opened 1 year ago

AceMcAwesome77 commented 1 year ago

Hi, I am training a model to detect blood clots in an abdominal artery on CT using the RetinaNet 3D model. To do this, I need to include image volumes with no clots; otherwise the model will just draw a box around that artery every time, whether or not it contains a clot. Is it possible to train on image volumes that have no boxes? I tried passing in JSON entries for these studies with an empty list for the box coordinates, like so:

    {
        "image": "1.2.392.200036.9116.2.6.120663787.309230_neg.nii.gz",
        "box": [],
        "label": []
    },

But this throws the following error:

    epoch 1/300
    Traceback (most recent call last):
      File "detection/luna16_training.py", line 476, in <module>
        main()
      File "detection/luna16_training.py", line 278, in main
        for batch_data in train_loader:
      File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
        data = self._next_data()
      File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1333, in _next_data
        return self._process_data(data)
      File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1359, in _process_data
        data.reraise()
      File "/opt/conda/lib/python3.8/site-packages/torch/_utils.py", line 543, in reraise
        raise exception
    RuntimeError: Caught RuntimeError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "/opt/conda/lib/python3.8/site-packages/monai/transforms/transform.py", line 102, in apply_transform
        return _apply_transform(transform, data, unpack_items)
      File "/opt/conda/lib/python3.8/site-packages/monai/transforms/transform.py", line 66, in _apply_transform
        return transform(parameters)
      File "/opt/conda/lib/python3.8/site-packages/monai/apps/detection/transforms/dictionary.py", line 188, in __call__
        d[key] = self.converter(d[key])
      File "/opt/conda/lib/python3.8/site-packages/monai/apps/detection/transforms/array.py", line 166, in __call__
        return convert_box_to_standard_mode(boxes, mode=self.mode)
      File "/opt/conda/lib/python3.8/site-packages/monai/data/box_utils.py", line 576, in convert_box_to_standard_mode
        return convert_box_mode(boxes=boxes, src_mode=mode, dst_mode=StandardMode())
      File "/opt/conda/lib/python3.8/site-packages/monai/data/box_utils.py", line 535, in convert_box_mode
        corners = src_boxmode.boxes_to_corners(boxes_t)
      File "/opt/conda/lib/python3.8/site-packages/monai/data/box_utils.py", line 308, in boxes_to_corners
        spatial_dims = get_spatial_dims(boxes=boxes)
      File "/opt/conda/lib/python3.8/site-packages/monai/data/box_utils.py", line 396, in get_spatial_dims
        if int(boxes.shape[1]) not in [4, 6]:
    IndexError: tuple index out of range

I am not surprised by the error, since the code appears to assume there will always be at least one box. Is there currently any way in the MONAI code to handle image volumes without boxes?
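
The IndexError at the bottom of the traceback can be reproduced without MONAI. A minimal sketch using only NumPy: an empty list converts to a 1-D array of shape (0,), so there is no second axis for `boxes.shape[1]` to read. The reshape to (0, 6) is an illustrative assumption about what a "zero boxes" array could look like in 3D, not a confirmed MONAI fix.

```python
import numpy as np

# An empty "box" list becomes a 1-D array with shape (0,).
empty_boxes = np.array([], dtype=np.float32)
print(empty_boxes.shape)

# Reading the second axis, as get_spatial_dims does in the traceback, fails
# because a 1-D shape tuple has only one entry.
try:
    int(empty_boxes.shape[1])
    raised = False
except IndexError:
    raised = True
print("IndexError raised:", raised)

# Assumption: a 3D box array with zero rows would have shape (0, 6),
# which does have a second axis to inspect.
reshaped = empty_boxes.reshape(0, 6)
print(reshaped.shape)
```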

Thanks!

AceMcAwesome77 commented 1 year ago

I do have a solution for this that may be helpful to implement:

In monai\apps\detection\transforms\dictionary.py, change the RandCropBoxByPosNegLabeld.randomize function to this:

    def randomize(  # type: ignore
        self,
        boxes: NdarrayOrTensor,
        image_size: Sequence[int],
        fg_indices: Optional[NdarrayOrTensor] = None,
        bg_indices: Optional[NdarrayOrTensor] = None,
        thresh_image: Optional[NdarrayOrTensor] = None,
    ) -> None:
        if np.sum(boxes.numpy()) > 0: # new code
            if fg_indices is None or bg_indices is None:
                # We don't require crop center to be within the boxes.
                # As long as the cropped patch contains a box, it is considered a foreground patch.
                # Positions within extended_boxes are crop centers for foreground patches
                extended_boxes_np = self.generate_fg_center_boxes_np(boxes, image_size)
                mask_img = convert_box_to_mask(
                    extended_boxes_np, np.ones(extended_boxes_np.shape[0]), image_size, bg_label=0, ellipse_mask=False
                )
                mask_img = np.amax(mask_img, axis=0, keepdims=True)[0:1, ...]
                fg_indices_, bg_indices_ = map_binary_to_indices(mask_img, thresh_image, self.image_threshold)
            else:
                fg_indices_ = fg_indices
                bg_indices_ = bg_indices
        else: # new code
            fg_indices_ = [] # new code
            image_size_list = list(image_size) # new code
            voxel_count = image_size_list[0] * image_size_list[1] * image_size_list[2] # new code
            bg_indices_ = list(range(voxel_count)) # new code

        self.centers = generate_pos_neg_label_crop_centers(
            self.spatial_size,
            self.num_samples,
            self.pos_ratio,
            image_size,
            fg_indices_,
            bg_indices_,
            self.R,
            self.allow_smaller,
        )
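
The background-only branch above can also be written so that it does not hard-code three spatial dimensions. The following is a sketch only: `background_only_indices` is a hypothetical helper, not part of MONAI, and it has not been tested against `generate_pos_neg_label_crop_centers`.

```python
from typing import Sequence, Tuple

import numpy as np


def background_only_indices(image_size: Sequence[int]) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: no foreground, every voxel counted as background."""
    # Total number of voxels, for any number of spatial dimensions.
    voxel_count = int(np.prod(image_size))
    fg_indices = np.array([], dtype=np.int64)
    # Flat (raveled) indices 0 .. voxel_count - 1.
    bg_indices = np.arange(voxel_count, dtype=np.int64)
    return fg_indices, bg_indices


fg, bg = background_only_indices((4, 5, 6))
print(len(fg), len(bg))
```

Using `np.arange` instead of `list(range(...))` avoids building a large Python list for full-size CT volumes.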

In the same file, in the RandCropBoxByPosNegLabeld.__call__ function, add this check to the "# crop boxes and labels" section:

            # crop boxes and labels
            if np.sum(boxes.numpy()) > 0: # new code
                boxcropper = SpatialCropBox(roi_slices=crop_slices)
                results[i][self.box_keys], cropped_labels = boxcropper(boxes, labels)
                for label_key, cropped_labels_i in zip(self.label_keys, cropped_labels):
                    results[i][label_key] = cropped_labels_i

In the same file, change the ConvertBoxToStandardModed.__call__ function to this:

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        if np.sum(d['box'].numpy()) > 0: # new code
            for key in self.key_iterator(d):
                d[key] = self.converter(d[key])
                self.push_transform(d, key, extra_info={"mode": self.converter.mode})
        return d
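
One caveat about the `np.sum(...) > 0` test used in the changes above: it treats a volume as empty whenever the coordinates sum to zero or less, which could happen with world coordinates centered on the origin, and `.numpy()` only exists on tensors. A possible alternative is to count elements instead. This is a sketch; `has_boxes` is a hypothetical helper, not a MONAI function, and it assumes the boxes are a NumPy array or something `np.asarray` can convert (such as a CPU tensor).

```python
import numpy as np


def has_boxes(boxes) -> bool:
    """Hypothetical helper: True when the box array holds at least one box."""
    # Counting elements does not depend on the coordinate values themselves.
    return np.asarray(boxes).size > 0


no_boxes = np.array([], dtype=np.float32)
# One box in world coordinates whose values happen to sum to zero.
centered_box = np.array([[-10.0, -10.0, -10.0, 10.0, 10.0, 10.0]])

print(has_boxes(no_boxes), has_boxes(centered_box))
# The sum-based check would wrongly treat the centered box as "no boxes".
print(bool(np.sum(centered_box) > 0))
```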

Comment out this section of monai\apps\detection\networks\retinanet_detector.py:

        # 1. Check if input arguments are valid
        #if self.training:
        #    check_training_targets(input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key)
        #    self._check_detector_training_components()

This last step is optional, but you can also comment out this warning in monai\transforms\utils.py; otherwise it will keep repeating:

        #warnings.warn(
        #    f"Num foregrounds {len(fg_indices)}, Num backgrounds {len(bg_indices)}, "
        #    f"unable to generate class balanced samples, setting `pos_ratio` to {pos_ratio}." 
        #) 
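
An alternative to editing the library file is to filter the warning from the training script with Python's standard `warnings` module. This is a sketch that assumes the message starts with "Num foregrounds", as in the snippet above; `suppress_pos_ratio_warning` is a hypothetical name. With DataLoader worker processes the filter may also need to be applied inside each worker, which is not shown here.

```python
import warnings


def suppress_pos_ratio_warning() -> None:
    """Hypothetical helper: silence the repeated class-balance warning."""
    # `message` is a regular expression matched against the start of the warning text.
    warnings.filterwarnings(
        "ignore",
        message=r"Num foregrounds \d+, Num backgrounds \d+",
    )


# Call once near the top of the training script.
suppress_pos_ratio_warning()
```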

With these changes, my code trained properly on a set of NIfTI image volumes both with and without boxes.