czg1225 / SlimSAM

[NeurIPS 2024] SlimSAM: 0.1% Data Makes Segment Anything Slim
Apache License 2.0
274 stars, 17 forks

evaluate algorithms such as SAM, EfficientSAM, etc. #15

Closed: ranpin closed this issue 5 months ago

ranpin commented 6 months ago

Hi, I found your work; it's great.

I am a newbie and am having some difficulty reproducing your experiments, so I would like to ask for your help.

I noticed that your paper includes a comparison experiment on the SA-1B dataset, and I successfully reproduced the SlimSAM training and validation results from it.

However, since you haven't released the validation scripts yet, I followed your experimental setup (https://github.com/czg1225/SlimSAM/issues/5#issuecomment-1875436503) and wrote my own scripts to validate FastSAM, MobileSAM, EfficientSAM and other algorithms on SA-1B. I couldn't reproduce the results listed in your table very well, especially for EfficientSAM (only 28%) and EdgeSAM.

Would it be possible for you to release the validation scripts for these comparison algorithms for learning purposes? I would really appreciate it. Thank you very much!

czg1225 commented 6 months ago

Hi @ranpin, this is a simple script for validating the original SAM on SA-1B. Other models can be evaluated by following a similar procedure.

import json

import cv2
import numpy as np
import torch
from pycocotools import mask as mask_utils
from segment_anything import SamPredictor, sam_model_registry

# `test_loader` (batches with "id" = image path, "annot" = annotation JSON path)
# and `calculate_iou` come from the author's own setup and are not shown here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Teacher model: original SAM ViT-H
teacher_model_type = 'vit_h'
checkpoint = 'checkpoints/sam_vit_h_qkv.pth'
teacher_model = sam_model_registry[teacher_model_type](checkpoint=checkpoint)
teacher_model.to(device)
teacher_model.eval()

predictor = SamPredictor(teacher_model)

with torch.no_grad():
    iou = 0
    num_images = len(test_loader)
    for j, batch in enumerate(test_loader):
        path = batch["id"][0]       # image path
        annot = batch["annot"][0]   # SA-1B annotation JSON path
        print(j, path, annot)

        image = cv2.imread(path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Compute the image embedding once per image instead of once per mask;
        # the predictions are identical, only much faster.
        predictor.set_image(image)

        with open(annot, encoding="utf-8") as f:
            annotations = json.load(f)["annotations"]

        sub_count = 0
        sub_iou = 0
        for example in annotations:
            sub_count += 1

            # Single positive point prompt taken from the SA-1B annotation
            input_point = np.array(example['point_coords'])
            input_label = np.array([1])

            # Ground-truth mask, stored as COCO RLE
            mask = mask_utils.decode(example["segmentation"])

            # SA-1B boxes are XYWH; convert to XYXY (unused here since box=None)
            x, y, w, h = example['bbox']
            input_box = np.array([x, y, x + w, y + h])

            teacher_masks, _, _ = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                box=None,
                multimask_output=False,
            )
            teacher_masks = teacher_masks[0]  # (1, H, W) -> (H, W)

            sub_iou += calculate_iou(teacher_masks, mask)

        sub_iou = sub_iou / sub_count  # mean IoU over this image's masks
        print(sub_iou)
        iou += sub_iou

    iou = iou / num_images  # mean of the per-image mean IoUs
    print("IoU", iou)

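
The script calls `calculate_iou`, which is never shown in this thread. Below is a minimal sketch of what it might look like, assuming both masks are binary arrays of the same (H, W) shape: `pycocotools` decodes an SA-1B RLE into an (H, W) `uint8` array, and `SamPredictor.predict(..., multimask_output=False)` returns a `(1, H, W)` boolean array, hence the `[0]` index. The empty-mask convention and the helper names `xywh_to_xyxy`, `mean_of_image_means` and `pooled_mean` are hypothetical, not the author's code. The two averaging helpers are included because the script reports a mean of per-image means, which generally differs from a mean over all masks; a different aggregation is one possible reason numbers may not match a published table.

```python
import numpy as np


def calculate_iou(pred_mask, gt_mask):
    """IoU between two binary masks of the same (H, W) shape (hypothetical sketch)."""
    pred = np.asarray(pred_mask).astype(bool)
    gt = np.asarray(gt_mask).astype(bool)
    if pred.shape != gt.shape:
        raise ValueError(f"mask shape mismatch: {pred.shape} vs {gt.shape}")
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        # Assumed convention: two empty masks count as a perfect match.
        return 1.0
    return float(np.logical_and(pred, gt).sum() / union)


def xywh_to_xyxy(bbox):
    """Convert an SA-1B [x, y, w, h] box to the XYXY box SamPredictor expects."""
    x, y, w, h = bbox
    return np.array([x, y, x + w, y + h])


def mean_of_image_means(per_image_ious):
    """Average IoU within each image first, then across images (the script's scheme).

    Images with no masks are skipped; the original script would divide by zero.
    """
    image_means = [sum(ious) / len(ious) for ious in per_image_ious if ious]
    return sum(image_means) / len(image_means)


def pooled_mean(per_image_ious):
    """Alternative scheme: average over every mask, ignoring image boundaries."""
    all_ious = [v for ious in per_image_ious for v in ious]
    return sum(all_ious) / len(all_ious)
```

For example, `calculate_iou(np.array([[1, 1], [0, 0]]), np.array([[1, 0], [0, 0]]))` returns `0.5`, and for per-image IoU lists `[[1.0, 0.5], [0.25]]` the two schemes give `0.5` and about `0.583` respectively.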
ranpin commented 5 months ago

Thanks, the problem is solved.