czg1225 / SlimSAM

[NeurIPS 2024] SlimSAM: 0.1% Data Makes Segment Anything Slim
Apache License 2.0

Low IoU when training with self-captured data #11

Closed: jarvishou829 closed this issue 7 months ago

jarvishou829 commented 8 months ago

Thanks for your great work. However, when I train with my own data, the IoU score is extremely low. Here is the log:

CUDA visible devices: 1                                                         
CUDA Device Name: NVIDIA GeForce RTX 4090                                       
===========================Parameter Settings===========================                                                                                        
Pruning Ratio: 0.5                                                              
VIT num_heads: 12                       
norm_type: mean                                                                 
imptype: Disturb                        
global: False                                                                   
learning rate: 0.0001
a_weight: 0.5                           
round_to 12                             
TRAIN_SIZE 7825 VAL_SIZE 200 GRAD_SIZE 1000 Epochs 20                           
===========================Pruning Start===========================             
/home/user/workspace/SlimSAM-master/torch_pruning/dependency.py:639: UserWarning: Unwrapped parameters detected: ['neck.3.bias', 'neck.1.bias', 'pos_embed', 'neck.1.weight', 'neck.3.weight'].
 Torch-Pruning will prune the last non-singleton dimension of a parameter. If you wish to customize this behavior, please provide an unwrapped_parameters argument.
  warnings.warn(warning_str)            
vit_b Pruning:                          
  Params: 89578240 => 45116800          
  Macs: 368858711040.0 => 185712844800.0                                        
  Output:                               
torch.Size([1, 256, 64, 64])            
torch.Size([600, 14, 14, 768])                                                  
torch.Size([12, 64, 64, 768])           
torch.Size([12, 64, 64, 3072])                                                  
torch.Size([25, 64, 64, 384])                                                   
------------------------------------------------------                          

save checkpoint                         
epoch: 0                                                                                                                                                        
IOU: 0.00037980064841486576 Best IOU 0.00037980064841486576                     
epoch: 1                                                                                                                                                        
IOU: 0.0003555969132073369 Best IOU 0.00037980064841486576                      
save checkpoint                                                                                                                                                 
epoch: 2                                                                        
IOU: 0.0004798856262954162 Best IOU 0.0004798856262954162                       
epoch: 3                                                                        
IOU: 0.00038100686134219785 Best IOU 0.0004798856262954162                      
epoch: 4                                                                        
IOU: 0.00033190775964380326 Best IOU 0.0004798856262954162                                                                                                      
epoch: 5                                                                                                                                                        
IOU: 0.00034291165492228654 Best IOU 0.0004798856262954162                      
Epoch 00007: reducing learning rate of group 0 to 5.0000e-05.                   
epoch: 6                                                                                                                                                        
IOU: 0.00033288924349753746 Best IOU 0.0004798856262954162

By the way, my mask files were converted to COCO RLE format from the original label files, which contain only polygons: I first rasterized the polygons into binary masks and then encoded those masks as COCO RLE. As a result, the converted JSON files do not contain the point_coords field that SA-1B annotations have, so I modified the code in prune_distill_step1.py as shown below. Could that be the main cause?

for example in dict_data:

    sub_count += 1

    # input_point = np.array(example['point_coords'])
    # input_label = np.array([1])

    mask = mask_utils.decode(example["segmentation"])

    # point_coords = transform.apply_coords(input_point, original_image_size)
    # coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
    # labels_torch = torch.as_tensor(input_label, dtype=torch.int, device=device)
    # coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
    # points = (coords_torch, labels_torch)

    # Model inference
    image_embedding,_,_,_,_ = model.image_encoder(input_image)
    sparse_embeddings, dense_embeddings = model.prompt_encoder(
        points=None, # points,
        boxes=None,
        masks=None,
    )
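For readers reproducing the polygon to binary mask to COCO RLE conversion described in the question, the sketch below illustrates the uncompressed COCO RLE layout in plain NumPy: pixels are read in column-major (Fortran) order, and the counts alternate between runs of 0s and 1s, always starting with a run of 0s (which may have length zero). Both helper names are hypothetical. In practice, pycocotools' `mask.encode` (which expects a Fortran-ordered `uint8` array) produces the compressed `counts` string seen in the JSON files, so this is only meant to make the format concrete, not to replace pycocotools.

```python
import numpy as np


def mask_to_uncompressed_rle(mask):
    """Hypothetical helper: encode a binary (H, W) mask as COCO-style uncompressed RLE."""
    arr = np.asarray(mask, dtype=np.uint8)
    # COCO RLE scans pixels column by column (Fortran order).
    pixels = arr.flatten(order="F")
    counts = []
    current = 0  # runs always start with the number of leading zeros
    run = 0
    for p in pixels:
        p = int(p)
        if p != current:
            counts.append(run)
            current = p
            run = 0
        run += 1
    counts.append(run)
    return {"size": [int(arr.shape[0]), int(arr.shape[1])], "counts": counts}


def uncompressed_rle_to_mask(rle):
    """Hypothetical helper: decode the uncompressed RLE above back to an (H, W) mask."""
    h, w = rle["size"]
    flat = np.zeros(h * w, dtype=np.uint8)
    pos, value = 0, 0
    for c in rle["counts"]:
        flat[pos:pos + c] = value
        pos += c
        value = 1 - value  # runs alternate: 0s, 1s, 0s, ...
    return flat.reshape((h, w), order="F")


# Usage: the first pixel is foreground, so the counts start with a zero-length run.
demo = np.array([[1, 0, 0, 0],
                 [1, 1, 0, 0],
                 [0, 1, 0, 0]], dtype=np.uint8)
rle = mask_to_uncompressed_rle(demo)
print(rle["counts"])  # [0, 2, 2, 2, 6]
assert np.array_equal(uncompressed_rle_to_mask(rle), demo)
```

The odd-indexed counts are foreground runs, so their sum equals the mask area (the `area` field in the JSON).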
czg1225 commented 8 months ago

Hi @jarvishou829, to generate the specific mask you want at inference time, you must provide a prompt. However, your code passes both points=None and boxes=None. Please include either a point prompt or a box prompt. I hope this helps.
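Because the converted annotations lack SA-1B's point_coords, one option (our assumption, not something SlimSAM prescribes) is to synthesize a single positive point from each binary mask. The hypothetical helper below picks the foreground pixel closest to the mask centroid and returns it in SAM's (x, y) order with label 1. The result would still need `transform.apply_coords` before being turned into tensors, as in the commented-out point-prompt code in the question.

```python
import numpy as np


def point_prompt_from_mask(mask):
    """Hypothetical helper: pick one positive point prompt inside a binary mask.

    Returns (point_coords, point_labels) with shapes (1, 2) and (1,); coordinates
    are in (x, y) pixel order, matching SAM's point convention.
    """
    ys, xs = np.nonzero(np.asarray(mask))
    if ys.size == 0:
        raise ValueError("mask has no foreground pixels")
    cy, cx = ys.mean(), xs.mean()
    # The centroid can fall outside non-convex masks (e.g. rings), so snap to
    # the nearest foreground pixel to guarantee the point lies on the object.
    idx = int(np.argmin((ys - cy) ** 2 + (xs - cx) ** 2))
    point_coords = np.array([[xs[idx], ys[idx]]], dtype=np.float64)
    point_labels = np.array([1])  # 1 = foreground point
    return point_coords, point_labels


mask = np.zeros((6, 10), dtype=np.uint8)
mask[2:5, 5:8] = 1  # rows 2-4, columns 5-7
coords, labels = point_prompt_from_mask(mask)
print(coords, labels)  # [[6. 3.]] [1]
```

Snapping to a real foreground pixel is a design choice: a raw centroid is cheaper but can land on background for hollow or C-shaped objects.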

jarvishou829 commented 8 months ago

Thanks for your quick reply. I tried a box prompt as you suggested; here is my code:

for example in dict_data:

    sub_count += 1

    # input_point = np.array(example['point_coords'])
    # input_label = np.array([1])

    input_box = np.array(example['bbox'])
    boxes = torch.as_tensor(input_box, device=device)
    boxes = boxes[None, :]

    mask = mask_utils.decode(example["segmentation"])

    # point_coords = transform.apply_coords(input_point, original_image_size)
    # coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
    # labels_torch = torch.as_tensor(input_label, dtype=torch.int, device=device)
    # coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
    # points = (coords_torch, labels_torch)

    # Model inference
    image_embedding,_,_,_,_ = model.image_encoder(input_image)
    sparse_embeddings, dense_embeddings = model.prompt_encoder(
        points=None, # points,
        boxes=boxes, # None,
        masks=None,
    )

The IoU rose from about 0.0003 to about 0.006, but it is still very low:

CUDA visible devices: 1
CUDA Device Name: NVIDIA GeForce RTX 4090                                                                                
===========================Parameter Settings===========================                                                 
Pruning Ratio: 0.5                                          
VIT num_heads: 12                                                                                                        
norm_type: mean                                             
imptype: Disturb                                            
global: False                                                                                                            
learning rate: 0.0001                                       
a_weight: 0.5                                                                                                            
round_to 12                                                                                                              
TRAIN_SIZE 7825 VAL_SIZE 200 GRAD_SIZE 1000 Epochs 20                                                                    
===========================Pruning Start===========================                                                      
/home/user/workspace/SlimSAM-master/torch_pruning/dependency.py:639: UserWarning: Unwrapped parameters detected: ['pos_embed', 'neck.3.weight', 'neck.1.weight', 'neck.3.bias', 'neck.1.bias'].
 Torch-Pruning will prune the last non-singleton dimension of a parameter. If you wish to customize this behavior, please provide an unwrapped_parameters argument.
  warnings.warn(warning_str)                                
vit_b Pruning:                                                                                                           
  Params: 89578240 => 45116800                              
  Macs: 368858711040.0 => 185712844800.0                                                                                 
  Output:                                                   
torch.Size([1, 256, 64, 64])                                                                                             
torch.Size([600, 14, 14, 768])                              
torch.Size([12, 64, 64, 768])                               
torch.Size([12, 64, 64, 3072])                                                                                           
torch.Size([25, 64, 64, 384])                               
------------------------------------------------------                                                                   

save checkpoint                                             
epoch: 0                                                                                                                 
IOU: 0.006476015047946888 Best IOU 0.006476015047946888                                                                  
save checkpoint                                                                                                          
epoch: 1                                                                                                                 
IOU: 0.006600135839482798 Best IOU 0.006600135839482798                                                                  
epoch: 2                                                                                                                 
IOU: 0.0062571275655627775 Best IOU 0.006600135839482798

My label JSON files have the following format; I believe I converted them to COCO RLE correctly:

{"annotations": [
    {"bbox": [174.0, 8.0, 132.0, 352.0], 
    "area": 23005, 
    "segmentation": {
        "size": [360, 480], 
        "counts": "..."
        }
    }, 
    {"bbox": [1.0, 223.0, 98.0, 137.0], 
    "area": 7273, 
    "segmentation": {
        "size": [360, 480], 
        "counts": "..."
        }
    }, 
    ......
], 
 "image": {
    "image_id": 16, "width": 480, "height": 360, "file_name": "sa_16.jpg"}}
czg1225 commented 7 months ago

Hi @jarvishou829, it seems the format of your box prompt is wrong. SAM expects boxes as [x, y, x+w, y+h] (top-left and bottom-right corners), but your boxes appear to be in [x, y, w, h] format.
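The conversion itself is a one-liner: add the width and height to the top-left corner. A minimal sketch (the helper name is ours, not part of SlimSAM or SAM):

```python
import numpy as np


def xywh_to_xyxy(box):
    """Convert a COCO/SA-1B style [x, y, w, h] box to SAM's [x0, y0, x1, y1]."""
    x, y, w, h = np.asarray(box, dtype=np.float64)
    return np.array([x, y, x + w, y + h])


# First annotation from the JSON above: [174, 8, 132, 352] -> [174, 8, 306, 360]
print(xywh_to_xyxy([174.0, 8.0, 132.0, 352.0]))
```

Returning a new array (rather than updating `input_box[2]` and `input_box[3]` in place) avoids accidentally converting the same box twice if the code path runs again.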

jarvishou829 commented 7 months ago

Thanks! That may be the problem. I converted the binary masks to COCO RLE with the pycocotools package and didn't notice that it stores the bboxes in [x, y, w, h] format. However, the IoU stays low even after I switch to [x, y, x+w, y+h] boxes. I will finish the two-stage training and keep debugging. Thanks again.
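An alternative that sidesteps the stored bbox field entirely (our suggestion, not from the thread) is to compute the XYXY box directly from the decoded binary mask. The hypothetical helper below uses an exclusive right/bottom edge (last foreground index + 1). That matches x + w and y + h if the stored width and height count pixels inclusively, which the first annotation above suggests (its y + h equals the image height of 360).

```python
import numpy as np


def mask_to_xyxy(mask):
    """Hypothetical helper: tight [x0, y0, x1, y1] box around a binary (H, W) mask.

    x1 and y1 are exclusive (last foreground column/row + 1), so x1 - x0 is the
    width in pixels.
    """
    ys, xs = np.nonzero(np.asarray(mask))
    if ys.size == 0:
        raise ValueError("mask has no foreground pixels")
    return np.array([xs.min(), ys.min(), xs.max() + 1, ys.max() + 1], dtype=np.float64)


mask = np.zeros((360, 480), dtype=np.uint8)
mask[8:360, 174:306] = 1  # same extent as the first annotation's bbox
print(mask_to_xyxy(mask))  # values: 174, 8, 306, 360
```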

jarvishou829 commented 7 months ago

I also found that the bbox format in the SA-1B JSON files appears to be [x, y, w, h]. Does that mean the boxes must also be converted when running inference on SA-1B with box prompts? Have you tried inference with box prompts, and does the IoU look good? Here is part of sa_1.json:

{"image": 
    {"image_id": 1, "width": 1500, "height": 2060, "file_name": "sa_1.jpg"}, 
     "annotations": [
        {"bbox": [866.0, 946.0, 132.0, 199.0], 
         "area": 14773, 
         "segmentation": 
            {"size": [2060, 1500], 
             "counts": "..."}, 
         "predicted_iou": 0.9523417353630066, 
         "point_coords": [[940.9375, 1034.5625]], 
         "crop_box": [622.0, 902.0, 567.0, 707.0], 
         "id": 523353737, 
         "stability_score": 0.9742233753204346
        },
czg1225 commented 7 months ago

Hi @jarvishou829, indeed, the SA-1B boxes need to be converted at inference time as well. For more details on running inference with prompts, refer to inference.py in our code repository. We have run inference experiments with bbox prompts and the results are quite good: the mIoU reached approximately 0.875 for SlimSAM-77 and about 0.900 for SlimSAM-50.

czg1225 commented 7 months ago

This is the code used in our bbox testing scripts. You could try testing with the pretrained SAM or SlimSAM models to see whether the problem persists.

mask = mask_utils.decode(example["segmentation"])
input_box = np.array(example['bbox'])
input_box = np.array([input_box[0],input_box[1],input_box[0]+input_box[2],input_box[1]+input_box[3]])
predictor.set_image(image)
our_masks, _, _ = predictor.predict(
                    point_coords=None,
                    point_labels=None,
                    box=input_box,
                    multimask_output=False,
                )
our_masks = our_masks[0]
iou += calculate_iou(our_masks, mask)
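The snippet above calls `calculate_iou` without showing it. The version below is a hypothetical stand-in for binary-mask IoU (intersection over union of foreground pixels), not the repository's own function. The explicit shape check is a deliberate choice: comparing a mask at one resolution against a mask at another would otherwise fail confusingly or silently produce a near-zero score.

```python
import numpy as np


def calculate_iou(pred_mask, gt_mask):
    """Hypothetical stand-in: IoU of two binary masks with the same shape."""
    pred = np.asarray(pred_mask).astype(bool)
    gt = np.asarray(gt_mask).astype(bool)
    if pred.shape != gt.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {gt.shape}")
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 1.0  # both masks empty; treating this as a match is a convention
    return float(np.logical_and(pred, gt).sum() / union)


a = np.zeros((4, 4), dtype=np.uint8)
a[:2, :] = 1   # rows 0-1
b = np.zeros((4, 4), dtype=np.uint8)
b[1:3, :] = 1  # rows 1-2
print(calculate_iou(a, b))  # 4 shared pixels / 12 covered pixels = 1/3
```

A mean IoU such as the mIoU figures quoted above would then be the average of these per-annotation scores.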
jarvishou829 commented 7 months ago

Thanks for the suggestion. I ran inference with my trained model and got an IoU above 0.8. But the IoU during training still looks wrong. Using the code below from prune_distill_step1.py, I saved and visualized student_masks and mask (the two arrays used to compute the IoU) and found that they appear to cover different objects. Should the IoU between student_masks and mask really be this low (under 0.01)?

with open(annot, encoding="utf-8") as f:
    dict_data = json.load(f)
    dict_data = dict_data["annotations"]
    sub_count = 0
    sub_iou = 0
    for example in dict_data:

        sub_count += 1

        input_box = np.array(example['bbox'])

        input_box[2] = input_box[2] + input_box[0]
        input_box[3] = input_box[3] + input_box[1]

        boxes = torch.as_tensor(input_box, device=device)
        boxes = boxes[None, :]

        mask = mask_utils.decode(example["segmentation"])

        # Model inference
        image_embedding,_,_,_,_ = model.image_encoder(input_image)
        sparse_embeddings, dense_embeddings = model.prompt_encoder(
            points=None, # points,
            boxes=boxes, # None,
            masks=None,
        )
        low_res_masks, iou_predictions = model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        student_masks = teacher_model.postprocess_masks(low_res_masks, input_size, original_image_size)
        student_masks = student_masks > teacher_model.mask_threshold
        student_masks = student_masks[0].detach().cpu().numpy()[0]
        np.save('log/stu_mask'+str(sub_count)+'.npy', student_masks)
        np.save('log/mask'+str(sub_count)+'.npy', mask)
        np.save('log/bbox'+str(sub_count)+'.npy', input_box)
        sub_iou += calculate_iou(student_masks, mask)

Here are visualizations of the saved student_masks and mask arrays: [two images: student_masks, mask]

czg1225 commented 7 months ago

Hi @jarvishou829, SlimSAM has almost the same inference pipeline as the original SAM. You can find the bbox-processing code at line 146 of SAM_predictor. It seems your code is missing the call to self.transform.apply_boxes:

if box is not None:
    box = self.transform.apply_boxes(box, self.original_size)
    box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
    box_torch = box_torch[None, :]
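To see why the missing transform matters: SAM's `ResizeLongestSide` resizes the image so that its longer side equals the encoder input length (1024 for SAM's ViT models), and `apply_boxes` scales the box corners by the same factors. The NumPy sketch below mirrors that logic as we read it in `segment_anything/utils/transforms.py`; the function names here are our own, not the library's API. For the 480×360 images in this thread the scale is 1024/480 ≈ 2.13 on both axes, so a box passed without the transform is effectively shrunk by about 0.47 toward the top-left corner and points at the wrong region.

```python
import numpy as np


def get_preprocess_shape(old_h, old_w, long_side_length=1024):
    """Size (h, w) after resizing so that the longer side equals long_side_length."""
    scale = long_side_length * 1.0 / max(old_h, old_w)
    new_h, new_w = old_h * scale, old_w * scale
    return int(new_h + 0.5), int(new_w + 0.5)


def apply_boxes(boxes, original_size, long_side_length=1024):
    """Rescale XYXY boxes from original-image pixels to the resized model input.

    Each box is treated as two (x, y) corners; x is scaled by new_w / old_w and
    y by new_h / old_h. Returns an array of shape (N, 4).
    """
    old_h, old_w = original_size
    new_h, new_w = get_preprocess_shape(old_h, old_w, long_side_length)
    # copy() so the caller's array is never modified in place
    corners = np.asarray(boxes, dtype=np.float64).reshape(-1, 2, 2).copy()
    corners[..., 0] *= new_w / old_w
    corners[..., 1] *= new_h / old_h
    return corners.reshape(-1, 4)


# First annotation of the 480x360 image, already converted to XYXY.
box = np.array([174.0, 8.0, 306.0, 360.0])
print(apply_boxes(box, original_size=(360, 480)))  # approx [[371.2, 17.07, 652.8, 768.0]]
```

Because the result already has shape (1, 4), the `[None, :]` indexing in the predictor code yields a (1, 1, 4) tensor for a single box.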
jarvishou829 commented 7 months ago

I think that's exactly it! I forgot to apply the transform when using box prompts. In prune_distill_step1.py, the box-prompt code should look like this. Many thanks for your kind reply!

# point prompts
# input_point = np.array(example['point_coords'])
# input_label = np.array([1])
# point_coords = transform.apply_coords(input_point, original_image_size)
# coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device)
# labels_torch = torch.as_tensor(input_label, dtype=torch.int, device=device)
# coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
# points = (coords_torch, labels_torch)

# box prompts
input_box = np.array(example['bbox'])
# SA-1B / COCO boxes are [x, y, w, h]; SAM expects [x0, y0, x1, y1]
input_box = np.array([input_box[0],input_box[1],input_box[0]+input_box[2],input_box[1]+input_box[3]])
# rescale from original-image pixels to the resized model input
input_box = transform.apply_boxes(input_box, original_image_size)
boxes = torch.as_tensor(input_box, dtype=torch.float, device=device)
boxes = boxes[None, :]

mask = mask_utils.decode(example["segmentation"])

# Model inference
image_embedding,_,_,_,_ = model.image_encoder(input_image)
sparse_embeddings, dense_embeddings = model.prompt_encoder(
    points=None, # points,
    boxes=boxes, # None,
    masks=None,
)