Hi @JoyLinWQ, in the paper we implemented the OVDMaskClassifier to classify segmentation masks, and trained prototypes using masks of the objects extracted with SAM (we used the annotated bounding boxes as input prompts). However, detector inference is much slower, as SAM needs to be run on the fly and classifying segmentation masks is more computationally demanding than classifying bounding boxes.
However, if you want to look into this, it should be easy to implement. You only need to write a SAMRPN class that runs SAM to generate mask proposals, and initialise it in the OVDDetector initialisation method:
elif self.classification == 'mask':
    self.rpn = SAMRPN()
Then, you just need to add the following logic to the OVDDetector (detector.py) forward method:
elif self.classification == 'mask':
    proposals_scores, proposals = self.rpn(images)
where the SAMRPN() forward method runs SAM and returns the inputs to be fed to OVDMaskClassifier.
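To make the flow concrete, here is a rough sketch of how the two snippets above could fit together in the detector forward pass (the surrounding names and the classifier call are only indicative, not the exact code in detector.py):

    # Sketch only: everything outside the elif is indicative, not the exact detector.py code
    def forward(self, images):
        if self.classification == 'box':
            proposals_scores, proposals = self.rpn(images)   # Faster R-CNN box proposals
        elif self.classification == 'mask':
            proposals_scores, proposals = self.rpn(images)   # SAMRPN: box proposals -> SAM mask proposals
        # ... the proposals are then scored against the learned prototypes by the classifier
        return self.classifier(images, proposals)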
Hi @xavibou, thanks for the guidance and explanation. Apologies for the delay, as I was trying to troubleshoot the error below.
I've tried implementing the SAMRPN() class, and have also added a MaskDataset(BaseDataset) class in general.
I encountered the following error when testing on (an even smaller subset of) your data:
Traceback (most recent call last):
  File "C:\Users\XXX\Documents\XXX\ovdsat\train.py", line 340, in <module>
    main(args)
  File "C:\Users\XXX\Documents\XXX\ovdsat\train.py", line 303, in main
    model = train(args, model, train_dataloader, val_dataloader, device)
  File "C:\Users\XXX\Documents\XXX\ovdsat\train.py", line 157, in train
    logits = model(prepare_image_for_backbone(images, args.backbone_type), masks, labels, aggregation=args.aggregation)
  File "C:\Users\XXX\anaconda3\envs\ovdsat\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\XXX\anaconda3\envs\ovdsat\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\XXX\Documents\XXX\ovdsat\models\classifier.py", line 198, in forward
    masked_cosine_similarity = cosine_sim.unsqueeze(1) * mask_batch.unsqueeze(2)  # Shape: [B, M, N, H, W]
RuntimeError: The size of tensor a (602) must match the size of tensor b (3) at non-singleton dimension 4
In the OVDMaskClassifier() forward() method, it seems like the input "masks" tensor has an additional dimension of size 1 at the first/second position, instead of the expected shape (B, max_masks, H, W). May I know how this can be resolved, please? Thank you.
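For reference, this is my reading of the shapes that should line up at classifier.py line 198, plus the quick sanity check I run on the dataloader output before the forward call (just a sketch; masks here is whatever my MaskDataset returns through the dataloader):

    # Shapes implied by classifier.py line 198 and its shape comment:
    #   cosine_sim: (B, N, H, W)  -> unsqueeze(1) -> (B, 1, N, H, W)
    #   mask_batch: (B, M, H, W)  -> unsqueeze(2) -> (B, M, 1, H, W)
    #   product:    (B, M, N, H, W)
    #
    # Sanity check before the forward call (assumption: masks should be (B, max_masks, H, W)):
    print('masks:', tuple(masks.shape))
    if masks.dim() == 5 and masks.size(1) == 1:
        masks = masks.squeeze(1)  # e.g. (B, 1, max_masks, H, W) -> (B, max_masks, H, W)
    assert masks.dim() == 4, f'expected (B, max_masks, H, W), got {tuple(masks.shape)}'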
Attached below is the SAMRPN() class, which I placed in a file called mask_rpn.py under the rpn/ folder.
import numpy as np
import torch
from typing import Any, Dict, List, Optional, Tuple
from torchvision.ops.boxes import batched_nms, box_area

from segment_anything.modeling import Sam
from segment_anything import SamPredictor

from detectron2.structures import ImageList
from detectron2.engine.defaults import DefaultTrainer
from detectron2.checkpoint import DetectionCheckpointer

from models.rpn.box_rpn import get_box_RPN


class SAMRPN:
    def __init__(
        self,
        sam_model: Sam,
        config_file: str = 'configs/FasterRCNN_FPN_DOTA_config.yaml',
        checkpoint_file: str = 'weights/FasterRCNN_FPN_DOTA_final_model.pth'
    ):
        self.sam_predictor = SamPredictor(sam_model)
        self.cfg, self.rpn_model = get_box_RPN(config_file, checkpoint_file)

    def set_image(self, image: np.ndarray, image_format: str = "RGB"):
        self.sam_predictor.set_image(image, image_format)
        self.image = image

    def generate_box_proposals(self):
        image_tensor = torch.as_tensor(self.image, dtype=torch.float32, device=self.rpn_model.device)
        image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # Convert to (B, C, H, W) format

        # Normalize images
        images = [(x - self.rpn_model.pixel_mean) / self.rpn_model.pixel_std for x in image_tensor]
        images = ImageList.from_tensors(
            images,
            self.rpn_model.backbone.size_divisibility,
            padding_constraints=self.rpn_model.backbone.padding_constraints,
        )
        features = self.rpn_model.backbone(images.tensor)

        with torch.no_grad():
            proposals, _ = self.rpn_model.proposal_generator(images, features, None)

        boxes = torch.stack([p.proposal_boxes.tensor for p in proposals])
        box_scores = torch.stack([p.objectness_logits / 10 for p in proposals])
        return boxes, box_scores

    def predict_masks(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False
    ):
        return self.sam_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=box,
            mask_input=mask_input,
            multimask_output=multimask_output,
            return_logits=return_logits,
        )

    # generate_masks_from_boxes
    def forward(
        self,
        boxes: torch.Tensor,
        box_scores: torch.Tensor,
        multimask_output: bool = True,
        return_logits: bool = False
    ):
        '''
        Generate mask proposals by prompting SAM with the box proposals.

        Args:
            boxes (torch.Tensor): Box proposals for the current image, shape (N, 4)
            box_scores (torch.Tensor): Objectness scores for the proposals, shape (N,)
        '''
        masks, all_scores, low_res_masks = [], [], []
        for box, score in zip(boxes, box_scores):
            mask, mask_score, low_res_mask = self.sam_predictor.predict(
                box=box.cpu().numpy(),
                multimask_output=multimask_output,
                return_logits=return_logits
            )
            masks.append(mask)
            all_scores.append(mask_score)
            low_res_masks.append(low_res_mask)

        masks = np.concatenate(masks, axis=0)
        all_scores = np.concatenate(all_scores, axis=0)
        low_res_masks = np.concatenate(low_res_masks, axis=0)
        return masks, all_scores, low_res_masks
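For completeness, this is roughly how I exercise the class on a single image while debugging (a sketch; the SAM checkpoint path and model variant are placeholders):

    # Rough usage sketch (checkpoint path / model variant are placeholders)
    from segment_anything import sam_model_registry

    sam_model = sam_model_registry['vit_b'](checkpoint='weights/sam_vit_b.pth')
    rpn = SAMRPN(sam_model)

    rpn.set_image(image)                               # image: (H, W, 3) uint8 RGB numpy array
    boxes, box_scores = rpn.generate_box_proposals()   # (B, N, 4) and (B, N) from the box RPN
    masks, scores, low_res = rpn.forward(boxes[0], box_scores[0])  # SAM masks for the first image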
Hi @xavibou, thank you for providing the code for bboxes.
May I check (i) whether there are plans to implement the code for masks, and (ii) what the expected availability is, based on what is indicated in ovdsat/models/detector.py below? Thank you :)