microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

ObjectDetectionTask support for MSI #1156

Open adamjstewart opened 1 year ago

adamjstewart commented 1 year ago

Summary

The ObjectDetectionTask has an in_channels parameter, but it isn't actually used for anything. At the moment, it seems that the trainer only supports RGB imagery. We should fix this.

Rationale

We specialize in MSI; how can we not support MSI?

Implementation

We're currently using torchvision backbones, which makes things more challenging. How hard would it be to switch to timm backbones?
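
For reference, a rough sketch of the timm route (a sketch only, not wired into torchvision's detection heads here): timm itself adapts the pretrained first conv to arbitrary channel counts via in_chans.

    import timm

    # Sketch: build a multi-channel feature extractor with timm.
    # timm handles in_chans != 3 by adapting the pretrained first conv
    # (repeating/averaging the RGB kernel weights).
    backbone = timm.create_model(
        'resnet50',
        pretrained=True,
        in_chans=8,          # e.g. 8-band MSI
        features_only=True,  # return intermediate feature maps
    )
    channels = backbone.feature_info.channels()  # per-stage channel counts

The remaining work would be wrapping those feature maps in an FPN so torchvision's FasterRCNN can consume them.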

Alternatives

Alternatively, we'll have to override the first convolutional layer of the torchvision backbone to accept additional input channels.

Additional information

No response

robmarkcole commented 3 months ago

The adaptation for multi-channel looks straightforward - not sure about handling the pretrained weights

https://github.com/allenai/vessel-detection-sentinels/blob/main/src/models/frcnn.py
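
For the pretrained weights, one common trick (a sketch, not specific to torchgeo) is to widen the first conv and copy the pretrained RGB kernels into it, initializing the extra channels from their mean:

    import torch

    def widen_first_conv(conv: torch.nn.Conv2d, in_channels: int) -> torch.nn.Conv2d:
        """Replace a pretrained 3-channel conv with an in_channels-wide copy.

        Assumes in_channels >= 3; extra channels are initialized with the
        mean of the pretrained RGB kernels.
        """
        new_conv = torch.nn.Conv2d(
            in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            bias=conv.bias is not None,
        )
        with torch.no_grad():
            new_conv.weight[:, :3] = conv.weight  # keep pretrained RGB kernels
            mean_w = conv.weight.mean(dim=1, keepdim=True)
            new_conv.weight[:, 3:] = mean_w.repeat(1, in_channels - 3, 1, 1)
        return new_conv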

robmarkcole commented 2 weeks ago

I am taking a stab at this (offline) - it appears straightforward:

    def configure_models(self) -> None:
        """Initialize the model.

        Raises:
            ValueError: If *model* or *backbone* are invalid.
        """
        backbone: str = self.hparams['backbone']
        model: str = self.hparams['model']
        weights: bool | None = self.hparams['weights']
        in_channels: int = self.hparams['in_channels']
        num_classes: int = self.hparams['num_classes']
        freeze_backbone: bool = self.hparams['freeze_backbone']

        if backbone in BACKBONE_LAT_DIM_MAP:
            kwargs = {
                'backbone_name': backbone,
                'trainable_layers': self.hparams['trainable_layers'],
            }
            if weights:
                kwargs['weights'] = BACKBONE_WEIGHT_MAP[backbone]
            else:
                kwargs['weights'] = None

            latent_dim = BACKBONE_LAT_DIM_MAP[backbone]
        else:
            raise ValueError(f"Backbone type '{backbone}' is not valid.")

        if model == 'faster-rcnn':
            model_backbone = resnet_fpn_backbone(**kwargs)

            if in_channels != 3:  # Adjust the first conv layer to match input channels
                first_conv_layer = model_backbone.body.conv1
                model_backbone.body.conv1 = torch.nn.Conv2d(
                    in_channels,
                    first_conv_layer.out_channels,
                    kernel_size=first_conv_layer.kernel_size,
                    stride=first_conv_layer.stride,
                    padding=first_conv_layer.padding,
                    bias=False
                )
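                # Note: the replacement conv1 is randomly initialized, so any
                # pretrained conv1 weights are discarded here.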

            anchor_generator = AnchorGenerator(
                sizes=((32,), (64,), (128,), (256,), (512,)),
                aspect_ratios=((0.5, 1.0, 2.0),) * 5,
            )

            roi_pooler = MultiScaleRoIAlign(
                featmap_names=['0', '1', '2', '3'], output_size=7, sampling_ratio=2
            )

            if freeze_backbone:
                for param in model_backbone.parameters():
                    param.requires_grad = False

            self.model = torchvision.models.detection.FasterRCNN(
                model_backbone,
                num_classes,
                rpn_anchor_generator=anchor_generator,
                box_roi_pool=roi_pooler,
            )
        else: 
            raise ValueError(f"Model type '{model}' is not valid.")

However, in use I get an error:

    139 if image.dim() != 3:
    140     raise ValueError(f"images is expected to be a list of 3d tensors of shape [C, H, W], got {image.shape}")
--> 141 image = self.normalize(image)
    142 image, target_index = self.resize(image, target_index)
    143 images[i] = image

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchvision/models/detection/transform.py:169, in GeneralizedRCNNTransform.normalize(self, image)
    167 mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
    168 std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
--> 169 return (image - mean[:, None, None]) / std[:, None, None]

RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

Since kornia already performs normalisation, why is torchvision raising this error?

robmarkcole commented 2 weeks ago

It appears torchvision is also performing normalisation - I get around this with a dummy transform:

    from torchvision.models.detection.transform import GeneralizedRCNNTransform

    class NoNormalizeTransform(GeneralizedRCNNTransform):
        def normalize(self, image):
            # Skip normalization; return the image as is
            return image

... and add to `configure_models`:

    self.model.transform = NoNormalizeTransform(
        min_size=800,
        max_size=1333,
        image_mean=[0.0, 0.0, 0.0],  # Dummy values, won't be used
        image_std=[1.0, 1.0, 1.0],   # Dummy values, won't be used
    )
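
An alternative to subclassing (a sketch, untested here) is to pass per-channel no-op statistics directly, since torchvision's FasterRCNN forwards image_mean/image_std to GeneralizedRCNNTransform:

    # Sketch: zero mean / unit std of the right length makes torchvision's
    # normalization a no-op, leaving normalization entirely to kornia.
    self.model = torchvision.models.detection.FasterRCNN(
        model_backbone,
        num_classes,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
        image_mean=[0.0] * in_channels,
        image_std=[1.0] * in_channels,
    )
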
adamjstewart commented 2 weeks ago

Would need to see the full traceback and code to reproduce the bug you saw.