luca-medeiros / lang-segment-anything

SAM with text prompt
Apache License 2.0
1.54k stars 168 forks source link

I am trying to use LangSam for SAM2 model. #70

Open plutus123 opened 1 month ago

plutus123 commented 1 month ago

@luca-medeiros These are the changes that I have had done I am trying to use LangSam for SAM2 model. I have had changed the code and made it compatible for SAM2 but now I am encountering issue . The changes that I have had done in lang_sam.py file are

import os
import numpy as np
import torch
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.inference import predict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from huggingface_hub import hf_hub_download
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import groundingdino.datasets.transforms as T

SAM_MODELS = {
    "sam2-hiera-tiny": ("sam2_hiera_t.yaml", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"),
    "sam2-hiera-small": ("sam2_hiera_s.yaml", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"),
    "sam2-hiera-base-plus": ("sam2_hiera_b+.yaml", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"),
    "sam2-hiera-large": ("sam2_hiera_l.yaml", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"),
}

CACHE_PATH = os.environ.get("TORCH_HOME", os.path.expanduser("~/.cache/torch/hub/checkpoints"))

def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    args = SLConfig.fromfile(cache_config_file)
    model = build_model(args)
    args.device = device
    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print(f"Model loaded from {cache_file} \n => {log}")
    model.eval()
    return model

def transform_image(image) -> torch.Tensor:
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_transformed, _ = transform(image, None)
    return image_transformed

class LangSAM:
    def __init__(self, sam_type="sam2-hiera-large", ckpt_path=None, return_prompts: bool = False):
        self.sam_type = sam_type
        self.return_prompts = return_prompts
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.build_groundingdino()
        self.build_sam(ckpt_path)

    def build_sam(self, ckpt_path):
        if self.sam_type is None or ckpt_path is None:
            if self.sam_type is None:
                print("No sam type indicated. Using sam2-hiera-large by default.")
                self.sam_type = "sam2-hiera-large"
            config_file, checkpoint_url = SAM_MODELS[self.sam_type]
            config_path = os.path.join(CACHE_PATH, config_file)
            checkpoint_path = os.path.join(CACHE_PATH, os.path.basename(checkpoint_url))
            if not os.path.exists(checkpoint_path):
                print(f"Downloading {self.sam_type} checkpoint...")
                torch.hub.download_url_to_file(checkpoint_url, checkpoint_path)
        else:
            config_path = self.sam_type
            checkpoint_path = ckpt_path

        try:
            sam2 = build_sam2(config_path, checkpoint_path, device=self.device)
            self.sam = SAM2ImagePredictor(sam2)
        except Exception as e:
            raise ValueError(f"Problem loading SAM: {str(e)}")

    def build_groundingdino(self):
        ckpt_repo_id = "ShilongLiu/GroundingDINO"
        ckpt_filename = "groundingdino_swinb_cogcoor.pth"
        ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
        self.groundingdino = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)

    def predict_dino(self, image_pil, text_prompt, box_threshold, text_threshold):
        image_trans = transform_image(image_pil)
        boxes, logits, phrases = predict(
            model=self.groundingdino,
            image=image_trans,
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device
        )
        W, H = image_pil.size
        boxes = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
        return boxes, logits, phrases

    def predict_sam(self, image_pil, boxes):
        image_array = np.asarray(image_pil)
        self.sam.set_image(image_array)
        transformed_boxes = self.sam._transforms.transform_boxes(boxes, image_array.shape[:2])
        masks, _, _ = self.sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        return masks

    def predict(self, image_pil, text_prompt, box_threshold=0.3, text_threshold=0.25):
        boxes, logits, phrases = self.predict_dino(image_pil, text_prompt, box_threshold, text_threshold)
        masks = torch.tensor([])
        if len(boxes) > 0:
            masks = self.predict_sam(image_pil, boxes)
        return masks, boxes, phrases, logits

And this is my LangSam Implementation:

from PIL import Image
from lang_sam import LangSAM
from lang_sam.utils import draw_image

# Specify the path to your configuration and checkpoint files
config_path = "/segment-anything-2/sam2_configs/sam2_hiera_l.yaml"
checkpoint_path = "/segment-anything-2/checkpoints/sam2_hiera_large.pt"

# Initialize LangSAM with explicit paths
model = LangSAM(sam_type=config_path, ckpt_path=checkpoint_path)

image_pil = Image.open('weld_img1.jpg').convert("RGB")
text_prompt = 'welds'
masks, boxes, labels, logits = model.predict(image_pil, text_prompt)
print(f"Masks: {masks.shape}\nBoxes: {boxes.shape}\nLabels: {labels}\nLogits: {logits.shape}")

# Visualize the results
draw_image(image_pil, masks, boxes, labels)

But still I am encountering an error:

NameError                                 Traceback (most recent call last)
[/content/drive/MyDrive/lang-segment-anything/lang_sam/lang_sam.py](https://localhost:8080/#) in build_sam(self, ckpt_path)
     95         self.sam.set_image(image_array)
---> 96         transformed_boxes = self.sam._transforms.transform_boxes(boxes, image_array.shape[:2])
     97         masks, _, _ = self.sam.predict_torch(

NameError: name 'sam_model_registry' is not defined

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
2 frames
[/content/drive/MyDrive/lang-segment-anything/lang_sam/lang_sam.py](https://localhost:8080/#) in build_sam(self, ckpt_path)
     96         transformed_boxes = self.sam._transforms.transform_boxes(boxes, image_array.shape[:2])
     97         masks, _, _ = self.sam.predict_torch(
---> 98             point_coords=None,
     99             point_labels=None,
    100             boxes=transformed_boxes,

ValueError: Problem loading SAM. Your model type: /segment-anything-2/sam2_configs/sam2_hiera_l.yaml                 should match your checkpoint path: /segment-anything-2/checkpoints/sam2_hiera_large.pt. Recommend calling LangSAM                 using matching model type AND checkpoint path

Please can anyone help me to resolve this error!!

luca-medeiros commented 1 month ago

Planning to release a new version sam2 compatible next week. Also clear all the conflict issues.

Hard to assist you since you have changed the original code and haven't shared your changes.

plutus123 commented 1 month ago

@luca-medeiros I have had edited the question now you can take a look at the changed code. Thanks!!

nandiniigarg commented 3 weeks ago

Hi @luca-medeiros, any updates on the release for Langsam with SAM2?