huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Support more memory efficient processing for segmentation models. #29546

Open jveitchmichaelis opened 7 months ago

jveitchmichaelis commented 7 months ago

Feature request

This feature request proposes changes that would significantly reduce memory consumption for segmentation models, particularly on datasets with large numbers of instances per image.

Motivation

Most (all?) of the models for segmentation tasks in transformers rely on the label/ground truth input being a list of instance masks. For images that contain large numbers of objects (for example aerial imagery, life sciences/microscopy) this can result in huge arrays being passed around. For example, a slide image containing 200 cells, each a separate instance, requires a mask input of shape 200 x H x W. At least on my computer, trying to process such datasets means I regularly get OOMs (even with 64GB RAM) unless I take care to limit the number of instances per sample.
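A back-of-the-envelope comparison makes the scaling concrete. The sketch below assumes a 1024 x 1024 image with 200 instances (the image size is an illustrative assumption) and compares a stack of one-byte boolean masks, the same stack as float32, and a single int32 map holding one segment ID per pixel. The numbers are per image, before any extra copies made by padding or batching.

```python
import numpy as np

def stacked_mask_bytes(num_instances: int, height: int, width: int, dtype) -> int:
    """Bytes needed for a [num_instances, height, width] stack of instance masks."""
    return num_instances * height * width * np.dtype(dtype).itemsize

def id_map_bytes(height: int, width: int, dtype=np.int32) -> int:
    """Bytes needed for a single [height, width] map holding one segment ID per pixel."""
    return height * width * np.dtype(dtype).itemsize

h = w = 1024
n = 200
print(stacked_mask_bytes(n, h, w, np.bool_) / 2**20)    # 200.0 (MiB)
print(stacked_mask_bytes(n, h, w, np.float32) / 2**20)  # 800.0 (MiB)
print(id_map_bytes(h, w) / 2**20)                       # 4.0 (MiB), independent of n
```

The stacked representation grows linearly with the number of instances, whereas the ID map stays the same size however many objects the image contains.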

This issue also affects torchvision's implementation of Mask R-CNN for the same reason, but I think Detectron2 (and possibly mmdet) can operate on polygon/RLE masks directly, and I've not had issues training instance segmentation models on inputs with large numbers of objects. (An alternative to this proposal would be to support encoding masks internally as RLE, which would also save a significant amount of memory.) My suspicion is that this hasn't been an issue because benchmark datasets like COCO have relatively few instances per image.
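To illustrate the RLE alternative, here is a minimal run-length encoder/decoder for a single binary mask. It is a simplified sketch (row-major order, alternating runs that always start with a run of zeros), not the COCO/pycocotools compressed RLE format, which uses column-major order and a string encoding of the counts.

```python
import numpy as np

def rle_encode(mask: np.ndarray) -> list[int]:
    """Encode a binary mask as alternating run lengths, starting with a run of zeros."""
    flat = mask.astype(bool).ravel()  # row-major order
    # Positions where the value changes, plus both ends of the array
    change = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    bounds = np.concatenate(([0], change, [flat.size]))
    runs = np.diff(bounds).tolist()
    if flat.size and flat[0]:
        runs = [0] + runs  # keep the "zeros first" convention
    return runs

def rle_decode(runs: list[int], shape: tuple[int, int]) -> np.ndarray:
    """Rebuild the binary mask from alternating zero/one run lengths."""
    values = (np.arange(len(runs)) % 2).astype(bool)  # False, True, False, ...
    return np.repeat(values, runs).reshape(shape)

mask = np.zeros((4, 6), dtype=bool)
mask[1:3, 2:5] = True             # a 3 px wide, 2 px tall object
runs = rle_encode(mask)
print(runs)                                           # [8, 3, 3, 3, 7]
print(np.array_equal(rle_decode(runs, mask.shape), mask))  # True
```

The number of runs scales with the length of the object boundaries rather than with the image area, which is why RLE stays small for many small objects.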

There are a couple of places where this situation can be improved, with significant gains in processing speed and memory usage. Perhaps the biggest advantage is the ability to use much larger batch sizes on memory-constrained machines.

(1) The first may be specific to DETR.

DetrForSegmentation's processor computes bounding boxes with a masks_to_boxes function, which operates on a stack of instance masks. This seems like an intentional decision, but I'm not sure why, unless we can't assume that the segments_info boxes are scaled correctly. This function is expensive and noticeably slow if you have e.g. 100 objects in an image. For object detection models, the processor simply loads the box coordinates from the annotations. In the panoptic regime we could do the same by querying segments_info, and fall back to mask processing if the bounding box info isn't provided.

This is a minor fix, but for some samples it gives me an order-of-magnitude improvement in data-loading speed (which, without this optimisation, can take much longer than the forward/backward pass).


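        # This is taken almost verbatim from the object detection processor.
        # Use the annotated boxes when available; otherwise fall back to masks_to_boxes.
        if target["segments_info"] and "bbox" in target["segments_info"][0]:
            # COCO boxes are [x, y, w, h]: convert to [x0, y0, x1, y1] and clip to the image
            boxes = [segment_info["bbox"] for segment_info in target["segments_info"]]
            boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
            boxes[:, 2:] += boxes[:, :2]
            boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
            boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)

            # Degenerate boxes are not filtered: a `keep` filter would also have to be
            # applied to masks and class_labels so they stay aligned with the boxes.
            # keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
            new_target["boxes"] = boxes
        else:
            new_target["boxes"] = masks_to_boxes(masks)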
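To make the cost difference concrete, here is a simplified sketch of the two paths. The function names are hypothetical and this is not the transformers implementation of masks_to_boxes: the mask-based path scans every pixel of every instance (O(N * H * W)), while the annotation-based path only converts N COCO [x, y, w, h] boxes. Note that the two conventions can differ by one pixel on the max edge: a box taken from mask pixel indices, as below, is inclusive, whereas x + w from a COCO box points one past the last pixel.

```python
import numpy as np

def boxes_from_masks(masks: np.ndarray) -> np.ndarray:
    """Mask-based boxes: scans every pixel of every [H, W] mask, O(N * H * W)."""
    boxes = np.zeros((masks.shape[0], 4), dtype=np.float32)
    for i, mask in enumerate(masks):
        ys, xs = np.nonzero(mask)
        if xs.size:  # empty masks keep an all-zero box
            boxes[i] = [xs.min(), ys.min(), xs.max(), ys.max()]
    return boxes

def boxes_from_coco(bboxes, image_width: int, image_height: int) -> np.ndarray:
    """Annotation-based boxes: O(N), COCO [x, y, w, h] -> clipped [x0, y0, x1, y1]."""
    boxes = np.asarray(bboxes, dtype=np.float32).reshape(-1, 4)
    boxes[:, 2:] += boxes[:, :2]
    boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
    boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
    return boxes

# One object 4 px wide and 3 px tall, top-left corner at (x=3, y=2), in an 8 x 10 image
masks = np.zeros((1, 8, 10), dtype=np.uint8)
masks[0, 2:5, 3:7] = 1
print(boxes_from_masks(masks))                 # [[3. 2. 6. 4.]]
print(boxes_from_coco([[3, 2, 4, 3]], 10, 8))  # [[3. 2. 7. 5.]]
```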

(2) The second matters more for memory, but is a more involved fix. Most of the models use the target masks to compute a mask loss of some kind: MaskFormer uses the same mask loss as DETR, while Mask2Former and OneFormer use a slightly different approach with a loss on sampled points.

For DETR, bounding box comparisons are used to match predictions (sources) to targets. The matched indices are then used to permute the predictions so that each target is paired with its source prediction, and to re-order the target masks so that the two can be compared. For MaskFormer/Mask2Former/OneFormer, the Hungarian matching algorithm is run on the masks themselves (see below).

The main issue here is not processing speed (passing around individual masks makes things simple to reason about), but the significant memory burden of passing around these massive instance arrays, which, almost by definition, become sparser the more objects are present. Instead, if we have access to (a) a panoptic mask as processed with rgb_to_id and (b) the segment IDs, ordered to match the input bounding boxes, we can iterate over the ground truth and pick off the mask for each object.
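A minimal NumPy sketch of the idea, assuming the target is a single [H, W] map of segment IDs (e.g. after rgb_to_id) plus an array of segment IDs ordered like the boxes/labels. The function names are hypothetical. Instance masks are materialised one at a time, only when needed, and the result matches the stacked representation.

```python
import numpy as np

def stacked_instance_masks(id_map: np.ndarray, segment_ids: np.ndarray) -> np.ndarray:
    """Current approach: materialise all N masks at once, shape [N, H, W]."""
    return id_map[None, :, :] == segment_ids[:, None, None]

def iter_instance_masks(id_map: np.ndarray, segment_ids: np.ndarray):
    """Proposed approach: yield one [H, W] mask at a time from the single ID map."""
    for segment_id in segment_ids:
        yield id_map == segment_id

rng = np.random.default_rng(0)
id_map = rng.integers(0, 5, size=(6, 8))  # toy segment-ID map with IDs 0..4
segment_ids = np.array([3, 1, 4])         # order matches the boxes/labels

stacked = stacked_instance_masks(id_map, segment_ids)
lazy = np.stack(list(iter_instance_masks(id_map, segment_ids)))
print(stacked.shape, np.array_equal(stacked, lazy))  # (3, 6, 8) True
```

At any point only the ID map and one [H, W] mask need to be alive, instead of all N masks.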

Performance-wise, I think this should be net zero, because this masking operation is normally done during data loading anyway to generate the individual instance masks. I'm sure a NumPy wizard could make the actual code more performant, but here is a possible implementation that (in my brief testing) gives identical losses to the existing loss_masks.

def loss_mask(self, outputs, targets, indices, num_boxes):
    """
    Compute the losses related to the masks: the focal loss and the dice loss.

    Each target dict must contain the key "mask": a tensor of shape [h, w] in which each pixel holds
    the ID of the segment it belongs to. It must also contain "segment_ids", ordered like the target
    boxes/labels, which are used to extract individual objects from the mask itself.

    Intended as a DetrLoss method: nn, nested_tensor_from_tensor_list, sigmoid_focal_loss and
    dice_loss are the ones already available in modeling_detr.
    """
    if "pred_masks" not in outputs:
        raise KeyError("No predicted masks found in outputs")

    source_idx = self._get_source_permutation_idx(indices)
    batch_idx, segment_idx = self._get_target_permutation_idx(indices)

    # Permute/filter outputs to one source per target
    source_masks = outputs["pred_masks"][source_idx]

    # Pad the target ID maps to a uniform shape
    # TODO use valid to mask invalid areas due to padding in loss
    masks = [t["mask"].unsqueeze(0) for t in targets]
    target_masks, _ = nested_tensor_from_tensor_list(masks).decompose()
    # Only move to the right device: keeping integer IDs makes the equality test below exact
    target_masks = target_masks.to(source_masks.device)

    # Upsample predictions to the target size
    source_masks = nn.functional.interpolate(
        source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
    )

    segment_ids = [t["segment_ids"] for t in targets]

    # Start from zero so both keys exist even when no predictions were matched
    losses = {"loss_mask": source_masks.new_zeros(()), "loss_dice": source_masks.new_zeros(())}

    # Calculate the loss one matched (prediction, target) pair at a time
    for s, batch, segment in zip(source_masks, batch_idx.tolist(), segment_idx.tolist()):
        # Binary mask for this object, shape [1, H*W]
        t = (target_masks[batch] == segment_ids[batch][segment]).flatten(1).to(s.dtype)
        s = s.flatten().unsqueeze(0)

        losses["loss_mask"] = losses["loss_mask"] + sigmoid_focal_loss(s, t, num_boxes)
        losses["loss_dice"] = losses["loss_dice"] + dice_loss(s, t, num_boxes)

    return losses
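Why the per-object loop can reproduce the stacked version exactly: both DETR-style losses reduce over pixels for each mask, then sum over masks and divide by num_boxes, so summing per-pair results gives the same total. The sketch below re-implements the two losses in NumPy, mirroring the DETR formulas as I understand them (dice with +1 smoothing, focal with alpha=0.25 and gamma=2), and checks that stacked and looped computations agree. It is a standalone illustration, not the transformers code.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def dice_loss(logits, targets, num_boxes):
    """Dice loss on [N, P] logits vs binary targets, summed over masks / num_boxes."""
    probs = sigmoid(logits)
    numerator = 2 * (probs * targets).sum(1)
    denominator = probs.sum(1) + targets.sum(1)
    return (1 - (numerator + 1) / (denominator + 1)).sum() / num_boxes

def sigmoid_focal_loss(logits, targets, num_boxes, alpha=0.25, gamma=2.0):
    """Focal loss averaged over pixels, summed over masks / num_boxes."""
    probs = sigmoid(logits)
    # Numerically stable binary cross-entropy with logits
    ce = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    p_t = probs * targets + (1 - probs) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * ce * (1 - p_t) ** gamma).mean(1).sum() / num_boxes

rng = np.random.default_rng(0)
id_map = rng.integers(0, 4, size=(5, 7))  # one image, segment IDs 0..3
segment_ids = np.array([2, 1, 3])         # matched targets, in order
logits = rng.normal(size=(3, 5 * 7))      # one matched prediction per target
num_boxes = len(segment_ids)

# Stacked: build all target masks at once
stacked = (id_map[None] == segment_ids[:, None, None]).reshape(3, -1).astype(float)
batched = (sigmoid_focal_loss(logits, stacked, num_boxes), dice_loss(logits, stacked, num_boxes))

# Looped: extract one target mask at a time from the ID map
looped = [0.0, 0.0]
for s, segment_id in zip(logits, segment_ids):
    t = (id_map == segment_id).reshape(1, -1).astype(float)
    looped[0] += sigmoid_focal_loss(s[None], t, num_boxes)
    looped[1] += dice_loss(s[None], t, num_boxes)

print(np.allclose(batched, looped))  # True
```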

The main user-facing difference here is that the preprocessor needs to pass the rest of "segments_info" (at least the segment IDs) through in the labels. Transformations may also need some handling, but in principle they should happen before processing/encoding: e.g. one loads the image and the annotations, applies any transformations, and the dataset returns the augmented sample, taking care to exclude segments that were cropped out.
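One way the dataset side could handle cropping, sketched under the assumption that augmentation is applied to the image and to the [H, W] segment-ID map together, before encoding. The function name is hypothetical. After cropping, only the segments_info entries whose ID still appears in the map are kept, and their COCO [x, y, w, h] boxes are shifted into the crop and clipped to it. A clipped box can be looser than a tight box recomputed from the cropped mask.

```python
import numpy as np

def crop_panoptic(id_map: np.ndarray, segments_info: list[dict],
                  top: int, left: int, height: int, width: int):
    """Crop an [H, W] segment-ID map and drop segments that no longer appear in it."""
    cropped = id_map[top:top + height, left:left + width]
    present = set(np.unique(cropped).tolist())
    kept = []
    for info in segments_info:
        if info["id"] not in present:
            continue  # segment was cropped out entirely
        x, y, w, h = info["bbox"]
        # Shift the box into crop coordinates and clip it to the crop
        x0, y0 = max(x - left, 0), max(y - top, 0)
        x1, y1 = min(x + w - left, width), min(y + h - top, height)
        kept.append({**info, "bbox": [x0, y0, x1 - x0, y1 - y0]})
    return cropped, kept

id_map = np.zeros((6, 6), dtype=np.int64)
id_map[0:2, 0:2] = 7   # small object in the top-left corner
id_map[3:6, 2:6] = 9   # larger object at the bottom
segments_info = [
    {"id": 7, "category_id": 1, "bbox": [0, 0, 2, 2]},
    {"id": 9, "category_id": 2, "bbox": [2, 3, 4, 3]},
]
cropped, kept = crop_panoptic(id_map, segments_info, top=2, left=1, height=4, width=5)
print([s["id"] for s in kept], kept[0]["bbox"])  # [9] [1, 1, 4, 3]
```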

For DETR, this modification is minor, but it reduces memory usage by 2-3 orders of magnitude in some cases. For me it enables training with a batch size of 8-16 images instead of 1-2, and I can run with many workers without hitting OOM. It also gives (almost) constant, predictable memory consumption during data loading, because the input mask is always a single H x W map regardless of the number of instances.

On MaskFormer/Mask2Former/OneFormer: the difference with these more recent models is that matching is done on a mask basis rather than a box basis (e.g. MaskFormerHungarianMatcher), but a similar approach could work. When computing the matching cost matrix, we would replace the following line with an iteration over the segment indices present in the target mask:

target_mask_flat = target_mask[:, 0].flatten(1) 

We would pay a penalty in speed, because presumably everything is well vectorised at the moment (loops bad?). However, I think having the option to pay that price in speed rather than in memory may be worth it (again, the same masking operation has to happen somewhere else anyway in order to generate the stack of instance masks).
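A rough sketch of this trade-off for a mask-based matcher, assuming a dice-style cost only (the actual MaskFormer-family matchers also include class and cross-entropy terms, and Mask2Former samples points). The function names are hypothetical. The looped version fills the cost matrix one target column at a time from the segment-ID map, so only one flattened target mask is alive at once, at the price of a Python loop over targets.

```python
import numpy as np

def pairwise_dice_cost_stacked(logits, target_masks):
    """Vectorised: [Q, P] logits vs a stacked [T, P] target array -> [Q, T] cost."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    numerator = 2 * probs @ target_masks.T
    denominator = probs.sum(-1)[:, None] + target_masks.sum(-1)[None, :]
    return 1 - (numerator + 1) / (denominator + 1)

def pairwise_dice_cost_looped(logits, id_map, segment_ids):
    """Memory-light: build one flattened target mask per segment ID, one column at a time."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    prob_sums = probs.sum(-1)
    cost = np.empty((logits.shape[0], len(segment_ids)))
    for j, segment_id in enumerate(segment_ids):
        target = (id_map.ravel() == segment_id).astype(probs.dtype)  # [P]
        numerator = 2 * probs @ target
        cost[:, j] = 1 - (numerator + 1) / (prob_sums + target.sum() + 1)
    return cost

rng = np.random.default_rng(0)
id_map = rng.integers(1, 5, size=(8, 8))  # toy segment-ID map, IDs 1..4
segment_ids = np.array([1, 2, 3, 4])
logits = rng.normal(size=(10, 64))        # 10 queries, 64 pixels each

stacked_targets = (id_map.ravel()[None, :] == segment_ids[:, None]).astype(float)  # [T, P]
cost_stacked = pairwise_dice_cost_stacked(logits, stacked_targets)
cost_looped = pairwise_dice_cost_looped(logits, id_map, segment_ids)
print(cost_looped.shape, np.allclose(cost_stacked, cost_looped))  # (10, 4) True
```

Either cost matrix would then go to the same assignment step (scipy.optimize.linear_sum_assignment in the existing matchers).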

Note that currently the matcher calculates the same costs as loss_masks in order to derive the cost matrix, but these scores are then discarded. Would it make more sense to use the source:target losses directly from the cost matrix once the matcher has run? i.e. loss_masks would just return a sum over the winning indices in the cost matrix.

Your contribution

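There are two primary contributions here:

(1) Use the bounding boxes from segments_info in the DETR panoptic preprocessing, falling back to masks_to_boxes when they are not provided.

(2) Compute the mask losses (and potentially the matching costs) from a single segment-ID map plus the segment IDs, instead of a stack of instance masks.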

I'm happy to open PRs for these, but would appreciate some discussion on the implementation and on any other considerations regarding the order of data loading and transformations.

jveitchmichaelis commented 6 months ago

Looking into this more, there is also some inconsistency in the scope of the image processors, and a lot of duplicated processing code with slight variations. For example:

DETR: requires COCO-formatted annotation files. Augmentation is extremely difficult to do canonically, because ideally we want to run preprocess on a batch, and the preprocessor loads the mask directly from disk.

Conditional DETR, Deformable DETR, DETA: same as above (I guess this applies to all DETR-derived models).

MaskFormer, Mask2Former, OneFormer, SAM: use a more consistent interface and allow passing segmentation_maps to preprocess.

So probably add to the todo list: update the DETR family of models to use more consistent data loading, on par with more recent architectures.

Most of the recent models re-use _preprocess, _preprocess_image and _preprocess_mask without modification. Would it make sense to refactor all of these into a single place and subclass where necessary?

It's a tedious job, but I would suggest creating a class derived from BaseImageProcessor, something like SegmentationImageProcessor, and re-using it in all the model-specific image processing classes with model-specific adjustments. I'm not sure what the style/standard is for this. There is some sense in having each model completely stand-alone (e.g. there are a lot of functions explicitly labeled "Copied from X"), but there is also some sense in centralising some of these functions so that bugs don't get copied around and the interfaces can be kept consistent.
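A minimal sketch of the centralisation idea using a template-method pattern. All class and method names here are hypothetical stand-ins, not existing transformers API: the shared base owns the common preprocess flow, and a model-specific subclass overrides only the step that differs (here, an illustrative label shift that maps ID 0 to an ignore index of 255).

```python
import numpy as np

class SegmentationImageProcessorSketch:
    """Hypothetical shared base: one preprocess flow with overridable steps."""

    def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5)):
        self.image_mean = np.asarray(image_mean, dtype=np.float32)
        self.image_std = np.asarray(image_std, dtype=np.float32)

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        # Shared default: scale an [H, W, C] uint8 image to [0, 1], then normalise per channel
        image = image.astype(np.float32) / 255.0
        return (image - self.image_mean) / self.image_std

    def preprocess_mask(self, segmentation_map: np.ndarray) -> np.ndarray:
        # Shared default: keep a single [H, W] map of segment IDs, no per-instance stack
        return segmentation_map.astype(np.int64)

    def preprocess(self, images, segmentation_maps=None) -> dict:
        batch = {"pixel_values": [self.preprocess_image(im) for im in images]}
        if segmentation_maps is not None:
            batch["labels"] = [self.preprocess_mask(m) for m in segmentation_maps]
        return batch

class ReducedLabelsProcessorSketch(SegmentationImageProcessorSketch):
    """Hypothetical model-specific subclass: overrides only the mask step."""

    def preprocess_mask(self, segmentation_map: np.ndarray) -> np.ndarray:
        # Illustrative adjustment: shift IDs down by one and send former ID 0 to 255 (ignore)
        mask = segmentation_map.astype(np.int64) - 1
        mask[mask < 0] = 255
        return mask

processor = ReducedLabelsProcessorSketch()
out = processor.preprocess([np.full((2, 2, 3), 255, dtype=np.uint8)], [np.array([[0, 1], [2, 0]])])
print(out["labels"][0].tolist())  # [[255, 0], [1, 255]]
```

The design choice is that fixes to the shared flow land in one place, while subclasses stay small and explicit about what they change.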

EDIT: On the Hungarian matcher, my mistake: in the paper, the loss is computed on two different sets of sampled points for matching and for the "final loss" computation, so the cost matrix can't be re-used.

NielsRogge commented 4 months ago

Pinging @qubvel here