
This is the PyTorch implementation adding new features to the Segment-Anything codebase. The features support batched inputs on the full-grid prompt (automatic mask generation), with post-processing that removes duplicated or small regions and holes, under a flexible input image size.

Full-Segment-Anything

This code originates from the Segment Anything Model below; all of the original code comes from Meta AI Research, FAIR.

Affiliation: Meta AI Research, FAIR

Authors: Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick

Explanation: The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.
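
For reference, the original prompt-based workflow looks roughly like the sketch below. This is a minimal example assuming the official `segment-anything` package and a downloaded ViT-H checkpoint; the image path and point coordinates are placeholders.

```python
# Minimal sketch of the original SAM prompt-based workflow (official API).
# Assumes `segment-anything` is installed and the ViT-H checkpoint is downloaded.
import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_h"](checkpoint="ckpt/sam_vit_h_4b8939.pth").cuda()
predictor = SamPredictor(sam)

image = np.array(Image.open("figure/sam1.png"))[..., :3]  # HxWx3 uint8 RGB
predictor.set_image(image)  # computes the image embedding once

# One foreground point prompt (label 1) at placeholder pixel coordinates (x, y)
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,  # returns three candidate masks with quality scores
)
```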


Why is Full-Segment-Anything needed?

The original Segment-Anything code has the following critical issues for further research (as sketched below):

- Automatic mask generation (the full-grid prompt) does not support batched inputs; images must be processed one at a time.
- The input image size is fixed at 1024, rather than flexible.

Therefore, Full-Segment-Anything addresses the above issues:

- Batched inputs on the full-grid prompt, with post-processing that removes duplicated or small regions and holes.
- Flexible input image size via the `custom_img_size` argument.

(We did not re-train the model; all modifications are at the code level.)
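
For comparison, automatic (full-grid) mask generation in the original code runs one image at a time at the fixed internal 1024 resolution, roughly as sketched below (official `segment-anything` API; paths are placeholders):

```python
# Original SAM: automatic mask generation is per image; batching requires a
# Python loop, and each input is internally resized/padded to the fixed 1024 size.
import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_h"](checkpoint="ckpt/sam_vit_h_4b8939.pth").cuda()
mask_generator = SamAutomaticMaskGenerator(sam)

images = [np.array(Image.open(f"figure/sam{i}.png"))[..., :3] for i in range(1, 5)]
all_masks = [mask_generator.generate(img) for img in images]  # one pass per image
```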


Version Update


Visualization of Full-Segment-Anything

Figure 1. Full-Segment-Anything on Image Resolution *128*
Figure 2. Full-Segment-Anything on Image Resolution *256*
Figure 3. Full-Segment-Anything on Image Resolution *512*
Figure 4. Full-Segment-Anything on Image Resolution *1024*
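
Each figure is produced by rebuilding SAM with a different `custom_img_size`; a minimal sketch, assuming the registry signature used in example.py (excerpted in the next section):

```python
# Build Full-Segment-Anything at each resolution shown in Figures 1-4.
from build_sam import sam_model_registry

for img_resolution in (128, 256, 512, 1024):
    sam = sam_model_registry['vit_h'](
        checkpoint='ckpt/sam_vit_h_4b8939.pth',
        custom_img_size=img_resolution,  # flexible input size added by this repo
    ).cuda()
    # ... then run the full-grid pipeline from example.py at this resolution
```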
## How to use Full-Segment-Anything?

In example.py, there is the code of Example 6, shown below. You can take this part and modify it to fit your own purpose. If you want to track the changed parts compared with the original SAM code, search for the keywords "by LBK EDIT" or "LBK", which mark the exact positions of the code changes. (Examples 1-5 walk through the trial and error used to investigate the problems of the original SAM code.)

```python
"""
Example 6: [LBK SAM] Batched Inputs -> Full Grid Prompts -> Multiple Mask Generation
with filtering of small and duplicated regions or holes [Very Hard]
"""
import os; os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
from PIL import Image
import torch
import torchvision
import matplotlib.pyplot as plt

from mask_generator import SamMaskGenerator
from utils.utils import show_mask, show_points, show_lbk_masks
from utils.amg import build_all_layer_point_grids
from build_sam import sam_model_registry

# img resolution
img_resolution = 1024

# Select the SAM size you want
sam = sam_model_registry['vit_h'](checkpoint='ckpt/sam_vit_h_4b8939.pth', custom_img_size=img_resolution).cuda()  # SAM ViT-H
# sam = sam_model_registry['vit_l'](checkpoint='ckpt/sam_vit_l_0b3195.pth', custom_img_size=img_resolution).cuda()  # SAM ViT-L
# sam = sam_model_registry['vit_b'](checkpoint='ckpt/sam_vit_b_01ec64.pth', custom_img_size=img_resolution).cuda()  # SAM ViT-B
# sam = sam_model_registry['vit_t'](checkpoint='ckpt/mobile_sam.pt', custom_img_size=img_resolution).cuda()  # Mobile-SAM

# prompt: a 16x16 full grid of foreground points, scaled to the image resolution
input_point = torch.as_tensor(
    build_all_layer_point_grids(16, 0, 1)[0] * img_resolution, dtype=torch.int64
).cuda()
input_label = torch.tensor([1 for _ in range(input_point.shape[0])]).cuda()

def prepare_image(image, img_resolution=img_resolution):
    # HWC numpy image -> CHW CUDA tensor, resized to the working resolution
    trans = torchvision.transforms.Compose(
        [torchvision.transforms.Resize((img_resolution, img_resolution))]
    )
    image = torch.as_tensor(image).cuda()
    return trans(image.permute(2, 0, 1))

# image upload
img1 = np.array(Image.open("figure/sam1.png"))[..., :3]
img2 = np.array(Image.open("figure/sam2.png"))[..., :3]
img3 = np.array(Image.open("figure/sam3.png"))[..., :3]
img4 = np.array(Image.open("figure/sam4.png"))[..., :3]

img_tensors = [prepare_image(img) for img in [img1, img2, img3, img4]]

# show the resized input images
for img_tensor in img_tensors:
    plt.figure(figsize=(5, 5))
    plt.imshow(img_tensor.permute(1, 2, 0).cpu().numpy())
    plt.axis('on')
    plt.show()

# batchify
batched_input = [
    {
        'image': x,
        'point_coords': input_point,
        'point_labels': input_label,
        'original_size': x.shape[1:],
    }
    for x in img_tensors
]

# LBK propagation
refined_masks = sam.individual_forward(batched_input, multimask_output=True)

# image mask generation visualization
for img_tensor, refined_mask in zip(img_tensors, refined_masks):
    plt.figure(figsize=(5, 5))
    plt.imshow(img_tensor.permute(1, 2, 0).cpu().numpy())
    show_lbk_masks(refined_mask.cpu().numpy(), plt)
    show_points(input_point.cpu().numpy(), input_label.cpu().numpy(), plt.gca())
    plt.title("[Full Grid] LBK Refined Mask", fontsize=18)
    plt.axis('on')
    plt.show()
```
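
As the example shows, every image in the batch shares the same full grid of 16 x 16 foreground points (built by `build_all_layer_point_grids(16, 0, 1)` and scaled to the image resolution), and a single call to `sam.individual_forward` performs the batched prediction along with the post-processing that removes duplicated or small regions and holes. The returned `refined_masks` contains one set of refined masks per input image, ready for visualization as above.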