weijielyu / Gaga

Gaga: Group Any Gaussians via 3D-aware Memory Bank
MIT License

About the SAM mask Data #3

Closed Korace0v0 closed 1 month ago

Korace0v0 commented 1 month ago

Hi, thanks for your work. Could you provide the SAM masks for Replica and ScanNet?

As GS-Grouping notes, the quality of the SAM masks largely determines the 3D segmentation results. We want to run a fair comparison, but it is proving really hard to reproduce matching mIoU scores.

Besides, can you also provide the evaluation scripts?

Many thanks for your help!

weijielyu commented 1 month ago

Hi, thanks for your interest!

For the SAM masks, we use the same hyperparameter settings as Gaussian Grouping. As mentioned in Sec. 4.1 of the paper, we use the following code to select SAM masks based on confidence scores.

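from typing import Dict

import numpy as np
import torch
from segment_anything import SamAutomaticMaskGenerator


def get_processed_sam_mask(config: Dict,
                           auto_sam: SamAutomaticMaskGenerator,
                           image: np.ndarray,
                           get_small_mask: bool) -> torch.Tensor:
    # get_small_mask is unused in this excerpt.
    device = auto_sam.predictor.device

    # NOTE: the indexing below expects generate() to return tensors under
    # 'masks' and 'iou_preds'. The stock segment_anything generator returns a
    # list of per-mask dicts instead, so a modified generator is assumed here.
    mask_data = auto_sam.generate(image)

    pred_masks = mask_data['masks'].float()  # (num_masks, H, W)
    pred_scores = mask_data['iou_preds']  # (num_masks,)

    # select by confidence threshold
    selected_indexes = (pred_scores >= config['CONFIDENCE_THRESHOLD'])
    selected_scores = pred_scores[selected_indexes]
    selected_masks = pred_masks[selected_indexes]
    _, m_H, m_W = selected_masks.shape
    # uint8 ids: supports at most 255 selected masks per image
    mask_id = np.zeros((m_H, m_W), dtype=np.uint8)

    # rank: paint masks in ascending score order so that higher-confidence
    # masks overwrite lower-confidence ones where they overlap
    selected_scores, ranks = torch.sort(selected_scores)
    ranks = ranks + 1  # shift to 1-based ids; 0 is background
    for index in ranks:
        mask_id[(selected_masks[index - 1] == 1).cpu().numpy()] = int(index)

    # compress the masks: relabel consecutively and drop masks with at most
    # 10% of their original area still visible
    mask_indices = np.unique(mask_id)
    cur_idx = 1
    output_mask = np.zeros((m_H, m_W), dtype=np.uint8)
    for idx in mask_indices:
        if idx == 0:
            continue  # skip background
        mask = (mask_id == idx)
        visible = int(mask.sum())
        original = selected_masks[int(idx) - 1].sum().item()
        if visible > 0 and (visible / original) > 0.1:
            output_mask[mask] = cur_idx
            cur_idx += 1

    output_mask = torch.tensor(output_mask, dtype=torch.int64, device=device)

    return output_mask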
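
For anyone trying this with an unmodified SAM install, here is a minimal NumPy-only sketch of the same selection steps (threshold by confidence, paint in ascending score order, then drop mostly-occluded masks and relabel). The function name `masks_to_label_map` and its parameters are hypothetical and not part of the Gaga code. The threshold has no default because the value of `CONFIDENCE_THRESHOLD` is not stated in this thread.

```python
import numpy as np


def masks_to_label_map(masks, scores, conf_thresh, min_visible_ratio=0.1):
    """Hypothetical helper: turn N binary masks into a single label map.

    masks:  (N, H, W) boolean array
    scores: (N,) confidence per mask
    Returns an (H, W) int32 map; 0 is background, ids start at 1.
    """
    masks = np.asarray(masks, dtype=bool)
    scores = np.asarray(scores, dtype=np.float32)
    h, w = masks.shape[1:]

    # 1) Keep only masks at or above the confidence threshold.
    keep = scores >= conf_thresh
    masks, scores = masks[keep], scores[keep]

    # 2) Paint in ascending score order so higher-confidence masks win overlaps.
    label_map = np.zeros((h, w), dtype=np.int32)
    for i in np.argsort(scores, kind="stable"):
        label_map[masks[i]] = i + 1

    # 3) Drop masks that are mostly covered by others, then relabel 1..K.
    output = np.zeros((h, w), dtype=np.int32)
    next_id = 1
    for i in range(len(masks)):
        visible = label_map == i + 1
        area = masks[i].sum()
        if area > 0 and visible.sum() / area > min_visible_ratio:
            output[visible] = next_id
            next_id += 1
    return output
```

With the stock `SamAutomaticMaskGenerator`, the inputs could be built from the list that `generate` returns, stacking each entry's `"segmentation"` array into `masks` and collecting each `"predicted_iou"` into `scores`.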

For our evaluation on the Replica and ScanNet datasets, please refer to https://github.com/weijielyu/Gaga/blob/main/eval.py
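
As a rough illustration of what an mIoU comparison between a predicted label map and a ground-truth label map can look like, here is a minimal sketch. It assumes a simple best-match rule (each ground-truth segment is scored against the predicted segment that overlaps it best), which may differ from what the linked `eval.py` does, so treat that script as the reference. The function name `mean_iou` and the `ignore_label` parameter are hypothetical.

```python
import numpy as np


def mean_iou(pred, gt, ignore_label=0):
    """Hypothetical helper: mean over GT segments of their best IoU with any predicted segment."""
    pred = np.asarray(pred)
    gt = np.asarray(gt)
    ious = []
    for g in np.unique(gt):
        if g == ignore_label:
            continue
        gt_mask = gt == g
        best = 0.0
        # Only predicted labels that touch this GT segment can have IoU > 0.
        for p in np.unique(pred[gt_mask]):
            if p == ignore_label:
                continue
            pred_mask = pred == p
            inter = np.logical_and(gt_mask, pred_mask).sum()
            union = np.logical_or(gt_mask, pred_mask).sum()
            best = max(best, inter / union)
        ious.append(best)
    return float(np.mean(ious)) if ious else 0.0
```

Because predicted ids are arbitrary, the score depends only on how segments overlap, not on whether the id values agree between the two maps.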

We plan to release the code in the near future, so please stay tuned. Thanks for your patience!

Korace0v0 commented 1 month ago

Thank you! This really helps.