facebookresearch / dinov2

PyTorch code and models for the DINOv2 self-supervised learning method.
Apache License 2.0

How is the +ms segmentation head implemented? #394

Open sabeaussan opened 3 months ago

sabeaussan commented 3 months ago

Hi all,

I'm trying to implement a multi-scale segmentation head as described in the paper:

"+ms: a boosted version of the linear setup. We concatenate the patch tokens of the 4 last layers, use a larger image resolution of 640, and use multiscale test-time augmentations to improve the predictions".

I get that I should use the get_intermediate_layers() method to retrieve the features of the n last layers, but what exactly are you referring to by "multiscale test-time augmentations"? I found this snippet of code that seems to implement this part, but I can't understand what is going on in it:

```python
def _transform_inputs(self, inputs):
    """Transform inputs for decoder.

    Args:
        inputs (list[Tensor]): List of multi-level img features.

    Returns:
        Tensor: The transformed inputs
    """
    if self.input_transform == "resize_concat":
        # accept lists (for cls token)
        input_list = []
        for x in inputs:
            if isinstance(x, list):
                input_list.extend(x)
            else:
                input_list.append(x)
        inputs = input_list
        # an image descriptor can be a local descriptor with resolution 1x1
        for i, x in enumerate(inputs):
            if len(x.shape) == 2:
                inputs[i] = x[:, :, None, None]
        # select indices
        inputs = [inputs[i] for i in self.in_index]
        # Resizing shenanigans
        # print("before", *(x.shape for x in inputs))
        if self.resize_factors is not None:
            assert len(self.resize_factors) == len(inputs), (len(self.resize_factors), len(inputs))
            inputs = [
                resize(input=x, scale_factor=f, mode="bilinear" if f >= 1 else "area")
                for x, f in zip(inputs, self.resize_factors)
            ]
            # print("after", *(x.shape for x in inputs))
        upsampled_inputs = [
            resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners)
            for x in inputs
        ]
        inputs = torch.cat(upsampled_inputs, dim=1)
```
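From stepping through it, the two resize calls seem to do the following; here is a standalone trace with a stand-in for the mmseg-style `resize` helper (the factor values are made up for illustration, not taken from an actual config):

```python
import torch
import torch.nn.functional as F

def resize(input, size=None, scale_factor=None, mode="bilinear", align_corners=None):
    # stand-in for the mmseg-style `resize` helper used in the snippet
    return F.interpolate(input, size=size, scale_factor=scale_factor,
                         mode=mode, align_corners=align_corners)

# pretend these are the patch-token maps of 4 layers, all at the same resolution
inputs = [torch.randn(1, 1024, 40, 40) for _ in range(4)]
resize_factors = [4, 2, 1, 0.5]                      # made-up values

# first resize: each map is rescaled by its own factor
inputs = [resize(x, scale_factor=f, mode="bilinear" if f >= 1 else "area")
          for x, f in zip(inputs, resize_factors)]
print([tuple(x.shape[2:]) for x in inputs])          # [(160, 160), (80, 80), (40, 40), (20, 20)]

# second resize: everything is brought back to the size of inputs[0] and concatenated
upsampled = [resize(x, size=inputs[0].shape[2:], mode="bilinear", align_corners=False)
             for x in inputs]
print(torch.cat(upsampled, dim=1).shape)             # torch.Size([1, 4096, 160, 160])
```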

So in _transform_inputs the features are first rescaled by resize_factors and then immediately resized back to a common shape before being concatenated. What is the point of doing that?
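And for the "multiscale test-time augmentations" part, my current guess (again an assumption, modelled on the usual mmsegmentation-style TTA, not confirmed from this repo) is that the same image is pushed through the segmentation model at several scales, optionally flipped, and the logits are resized back to the original resolution and averaged:

```python
import torch
import torch.nn.functional as F

def multiscale_tta(seg_model, img, scales=(0.75, 1.0, 1.25), flip=True):
    """Guess at multiscale TTA: average class probabilities over rescaled
    (and optionally horizontally flipped) copies of the input image."""
    _, _, H, W = img.shape
    acc, n = 0.0, 0
    for s in scales:
        scaled = F.interpolate(img, scale_factor=s, mode="bilinear", align_corners=False)
        views = [scaled, torch.flip(scaled, dims=[3])] if flip else [scaled]
        for i, view in enumerate(views):
            logits = seg_model(view)                           # (B, num_classes, h, w)
            logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
            if i == 1:
                logits = torch.flip(logits, dims=[3])          # undo the horizontal flip
            acc = acc + logits.softmax(dim=1)
            n += 1
    return acc / n                                             # (B, num_classes, H, W)
```

Is that roughly what the "+ms" evaluation does, or is something else going on?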