xavibou / ovdsat

Official implementation of the paper 'Exploring Robust Features for Few-Shot Object Detection in Satellite Imagery'

Any development for mask prediction? #3

Closed · JoyLinWQ closed 5 months ago

JoyLinWQ commented 6 months ago

Hi @xavibou , thank you for providing the code for bboxes.

May I check: i) are there plans to implement the code for masks, and ii) when is the expected availability? I ask based on the snippet in ovdsat/models/detector.py below:

elif self.classification == 'mask':
    raise NotImplementedError('Mask RPN not implemented yet. Should use SAM to generate proposals.')

Thank you :)

xavibou commented 6 months ago

Hi @JoyLinWQ, in the paper we implemented the OVDMaskClassifier to classify segmentation masks, and we trained the prototypes using object masks extracted with SAM (we used the annotated bounding boxes as input prompts). However, detector inference is much slower, as SAM needs to be run on the fly and classifying segmentation masks is more computationally demanding than classifying bounding boxes.
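For reference, the box-prompted mask extraction can be sketched roughly as follows (a minimal sketch; the checkpoint path and the image and box variables are placeholders, not the repo's actual training code):

import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Placeholder checkpoint path; any SAM variant is loaded the same way
sam = sam_model_registry['vit_h'](checkpoint='weights/sam_vit_h.pth')
predictor = SamPredictor(sam)

predictor.set_image(image)  # image: HxWx3 uint8 RGB array
# box: annotated bounding box in XYXY pixel coordinates, shape (4,)
masks, scores, _ = predictor.predict(box=box, multimask_output=False)
# masks[0] is the HxW boolean mask used to build the class prototype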

That said, if you want to look into this, it should be straightforward to implement. You only need to write a SAMRPN class that runs SAM to generate mask proposals, and initialise it in the OVDDetector initialisation method:

elif self.classification == 'mask':
    self.rpn = SAMRPN()

Then, you just need to add the following logic to the OVDDetector (detector.py) forward method:

elif self.classification == 'mask':
    proposals_scores, proposals = self.rpn(images)

where the SAMRPN() forward method runs SAM and returns the inputs to be fed to OVDMaskClassifier.
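A minimal sketch of such a class, built on SAM's automatic mask generator (the exact output format OVDMaskClassifier expects is an assumption here, not the repo's actual interface):

import numpy as np
import torch
from segment_anything import SamAutomaticMaskGenerator

class SAMRPNSketch:
    '''Sketch: turn SAM's automatic masks into (scores, masks) proposals.'''

    def __init__(self, sam_model):
        self.mask_generator = SamAutomaticMaskGenerator(sam_model)

    def __call__(self, images):
        # images: list of HxWx3 uint8 RGB arrays, one per batch element
        batch_scores, batch_masks = [], []
        for image in images:
            # (assumes SAM finds at least one mask per image)
            anns = self.mask_generator.generate(image)
            masks = np.stack([a['segmentation'] for a in anns])    # (N, H, W)
            scores = np.array([a['predicted_iou'] for a in anns])  # (N,)
            batch_masks.append(torch.from_numpy(masks))
            batch_scores.append(torch.from_numpy(scores))
        return batch_scores, batch_masks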

JoyLinWQ commented 5 months ago

Hi @xavibou, thanks for the guidance and explanation. Apologies for the delay; I was trying to troubleshoot the error below. I've implemented the SAMRPN() class, and also added a MaskDataset(BaseDataset) class.

I encountered the following error when testing on (an even smaller subset of) your data:

Traceback (most recent call last):
  File "C:\Users\XXX\Documents\XXX\ovdsat\train.py", line 340, in <module>
    main(args)
  File "C:\Users\XXX\Documents\XXX\ovdsat\train.py", line 303, in main
    model = train(args, model, train_dataloader, val_dataloader, device)
  File "C:\Users\XXX\Documents\XXX\ovdsat\train.py", line 157, in train
    logits = model(prepare_image_for_backbone(images, args.backbone_type), masks, labels, aggregation=args.aggregation)
  File "C:\Users\XXX\anaconda3\envs\ovdsat\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\XXX\anaconda3\envs\ovdsat\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\XXX\Documents\XXX\ovdsat\models\classifier.py", line 198, in forward
    masked_cosine_similarity = cosine_sim.unsqueeze(1) * mask_batch.unsqueeze(2)  # Shape: [B, M, N, H, W]
RuntimeError: The size of tensor a (602) must match the size of tensor b (3) at non-singleton dimension 4

In the OVDMaskClassifier() forward() method, it seems the input "masks" has an extra dimension of size 1 in the first or second position, instead of the expected shape (B, max_masks, H, W). May I know how this can be resolved, please? Thank you.
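For context, I suspect SAM's multimask output may be involved, since predict() with multimask_output=True returns three candidate masks per box. A reshaping step like the following is what I imagine is needed (hypothetical; names are placeholders):

import torch

# masks as stacked SAM outputs, e.g. (num_boxes, 3, H, W)
masks = torch.as_tensor(masks)
masks = masks[:, 0]           # keep one candidate per box -> (num_boxes, H, W)
masks = masks.unsqueeze(0)    # add a batch dimension -> (1, num_boxes, H, W)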

Attached below is the SAMRPN() class, which I placed in a file called mask_rpn.py under the rpn/ folder.

import numpy as np
import torch
from typing import Optional

from segment_anything.modeling import Sam
from segment_anything import SamPredictor
from detectron2.structures import ImageList

from models.rpn.box_rpn import get_box_RPN

class SAMRPN:
    '''Generates mask proposals by prompting SAM with box proposals
    obtained from the pretrained box RPN.'''

    def __init__(
        self,
        sam_model: Sam,
        config_file: str = 'configs/FasterRCNN_FPN_DOTA_config.yaml',
        checkpoint_file: str = 'weights/FasterRCNN_FPN_DOTA_final_model.pth'
    ):
        self.sam_predictor = SamPredictor(sam_model)
        self.cfg, self.rpn_model = get_box_RPN(config_file, checkpoint_file)

    def set_image(self, image: np.ndarray, image_format: str = "RGB"):
        # Precompute the SAM image embedding and cache the image for the box RPN
        self.sam_predictor.set_image(image, image_format)
        self.image = image

    def generate_box_proposals(self):
        image_tensor = torch.as_tensor(self.image, dtype=torch.float32, device=self.rpn_model.device)
        image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # Convert to (B, C, H, W) format

        # Normalize images
        images = [(x - self.rpn_model.pixel_mean) / self.rpn_model.pixel_std for x in image_tensor]
        images = ImageList.from_tensors(
            images,
            self.rpn_model.backbone.size_divisibility,
            padding_constraints=self.rpn_model.backbone.padding_constraints,
        )

        features = self.rpn_model.backbone(images.tensor)

        with torch.no_grad():
            proposals, _ = self.rpn_model.proposal_generator(images, features, None)

        # Stack per-image proposal boxes (B, N, 4) and scaled objectness scores (B, N)
        boxes = torch.stack([p.proposal_boxes.tensor for p in proposals])
        box_scores = torch.stack([p.objectness_logits / 10 for p in proposals])

        return boxes, box_scores

    def predict_masks(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False
    ):
        return self.sam_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=box,
            mask_input=mask_input,
            multimask_output=multimask_output,
            return_logits=return_logits,
        )

    def forward(
        self,
        boxes: torch.Tensor,
        box_scores: torch.Tensor,
        multimask_output: bool = True,
        return_logits: bool = False
    ):
        '''
        Generate mask proposals by prompting SAM with the given boxes.

        Args:
            boxes (torch.Tensor): Box prompts for the current image, shape (N, 4)
            box_scores (torch.Tensor): Objectness score for each box, shape (N,)
        '''
        masks, all_scores, low_res_masks = [], [], []

        for box, score in zip(boxes, box_scores):
            mask, mask_score, low_res_mask = self.sam_predictor.predict(
                box=box.cpu().numpy(),
                multimask_output=multimask_output,
                return_logits=return_logits
            )
            masks.append(mask)
            all_scores.append(mask_score)
            low_res_masks.append(low_res_mask)

        # Note: with multimask_output=True, SAM returns 3 candidate masks per
        # box, so this concatenation yields (3 * N, H, W) rather than (N, H, W)
        masks = np.concatenate(masks, axis=0)
        all_scores = np.concatenate(all_scores, axis=0)
        low_res_masks = np.concatenate(low_res_masks, axis=0)

        return masks, all_scores, low_res_masks
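For completeness, here is how I call the class above (assuming a loaded Sam model and an RGB uint8 image; forward() expects the boxes of a single image):

rpn = SAMRPN(sam_model)
rpn.set_image(image)                              # HxWx3 uint8 RGB
boxes, box_scores = rpn.generate_box_proposals()  # (B, N, 4), (B, N)
masks, scores, low_res = rpn.forward(boxes[0], box_scores[0])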