SHI-Labs / Matting-Anything

Matting Anything Model (MAM), an efficient and versatile framework for estimating the alpha matte of any instance in an image with flexible and interactive visual or linguistic user prompt guidance.
https://arxiv.org/abs/2306.05399
MIT License
615 stars 49 forks source link

Inference on your own image/trimap pair #21

Open heorhiikalaichev opened 1 year ago

heorhiikalaichev commented 1 year ago

Hi!

Thanks for your work! Could you share a script that takes a single image and its trimap as input?

antithing commented 1 year ago

+1

Inkyl commented 2 months ago

Hi!

Thanks for your work! Could you share a script that takes a single image and its trimap as input?

I have the same question. Do you know how to do this?

PeterVennerstrom commented 1 week ago
from typing import Optional

import cv2
import numpy as np
import toml
import torch
from segment_anything.utils.transforms import ResizeLongestSide
from torch.nn import functional as F

import networks
import utils
from utils import CONFIG

def prep_sample(image_path: str, bbox: list, side: int = 1024) -> dict:
    """Load an image and its box prompt and convert both to SAM input space.

    The image is read with OpenCV and converted BGR -> RGB. It is then
    resized with ``ResizeLongestSide(side)``, normalized with the per-channel
    mean/std constants below, and zero-padded on the bottom/right to
    ``side x side``. The box is rescaled with the same transform.

    Args:
        image_path: path to an image file readable by ``cv2.imread``.
        bbox: box prompt as ``[xmin, ymin, xmax, ymax]`` in original-image
            pixel coordinates.
        side: target length of the longer image side (default setting for
            SAM inference space).

    Returns:
        dict with keys:
            "image": float tensor of shape (1, 3, side, side) on CUDA.
            "bbox": float tensor on CUDA holding the box in resized
                coordinates, with a leading batch dimension added.
            "ori_shape": (H, W) of the original image.
            "pad_shape": (h, w) of the resized image *before* padding;
                ``inference`` crops predictions back to this size.

    Raises:
        OSError: if OpenCV cannot read the image (missing file or
            unsupported format).
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image: {image_path!r}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_size = image.shape[:2]

    transform = ResizeLongestSide(side)

    # HWC uint8 -> CHW tensor on the GPU.
    image = transform.apply_image(image)
    image = torch.as_tensor(image).cuda()
    image = image.permute(2, 0, 1).contiguous()

    boxes = transform.apply_boxes(np.array(bbox), original_size)
    boxes = torch.as_tensor(boxes, dtype=torch.float).cuda()

    pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1).cuda()
    pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1).cuda()
    image = (image - pixel_mean) / pixel_std

    # Remember the unpadded size so predictions can be cropped back later.
    pad_size = image.shape[-2:]
    h, w = pad_size
    image = F.pad(image, (0, side - w, 0, side - h))

    return {
        "image": image[None, ...],
        "bbox": boxes[None, ...],
        "ori_shape": original_size,
        "pad_shape": pad_size,
    }

def build_model(config_path, checkpoint_path):
    """Create the MAM model from a TOML config and load its M2M weights.

    Args:
        config_path: path to a TOML config file. Its parsed contents are
            passed to ``utils.load_config``, which presumably populates the
            ``CONFIG`` object read below (TODO confirm).
        checkpoint_path: path to a checkpoint whose ``"state_dict"`` entry
            holds the weights for ``model.m2m``.

    Returns:
        The model moved to CUDA and switched to eval mode.
    """
    # TOML files are UTF-8 by specification; don't rely on the locale.
    with open(config_path, encoding="utf-8") as f:
        utils.load_config(toml.load(f))

    model = networks.get_generator_m2m(
        seg=CONFIG.model.arch.seg, m2m=CONFIG.model.arch.m2m
    )
    model.cuda()

    # Load to CPU so a checkpoint saved on a different GPU index still
    # loads; load_state_dict copies the tensors into the CUDA parameters.
    # NOTE(security): torch.load unpickles -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Only the m2m submodule is loaded here; presumably the segmentation
    # backbone gets its weights inside get_generator_m2m -- TODO confirm.
    model.m2m.load_state_dict(
        utils.remove_prefix_state_dict(checkpoint["state_dict"]), strict=True
    )

    return model.eval()

@torch.no_grad()
def inference(
    model: networks.generator_m2m.sam_m2m,
    image_dict: dict,
    os8_width: Optional[int] = 10,
    os4_width: Optional[int] = 20,
    os1_width: Optional[int] = 10,
    twoside: Optional[bool] = False,
    maskguide: Optional[bool] = False,
):
    """Predict an alpha matte for a sample produced by ``prep_sample``.

    The three alpha predictions returned by ``model.forward_inference`` are
    fused coarse-to-fine:

    - The base is the os8 prediction. With ``maskguide``, it is instead the
      model's ``post_mask`` with its unknown band filled from os8.
    - The os4 prediction then overwrites the pixels the unknown-region
      helper selects on the current estimate.
    - Finally, os1 does the same.

    Args:
        model: model returned by ``build_model``.
        image_dict: dict with "image", "bbox", "ori_shape" and "pad_shape"
            as built by ``prep_sample``.
        os8_width, os4_width, os1_width: ``rand_width`` passed to the
            unknown-region helper at each refinement stage.
        twoside: use ``utils.get_unknown_tensor_from_mask`` instead of
            ``utils.get_unknown_tensor_from_mask_oneside``.
        maskguide: start from ``post_mask`` instead of the os8 prediction.

    Returns:
        uint8 numpy array of shape (H, W, C) at the original image size,
        with alpha scaled to 0..255 (C is presumably 1 -- confirm).
    """
    get_unknown_func = (
        utils.get_unknown_tensor_from_mask
        if twoside
        else utils.get_unknown_tensor_from_mask_oneside
    )
    _, pred, post_mask = model.forward_inference(image_dict)

    pad_h, pad_w = image_dict["pad_shape"][0], image_dict["pad_shape"][1]
    ori_shape = image_dict["ori_shape"]

    def _to_original_size(alpha):
        # Drop the zero padding added in prep_sample, then undo the resize.
        alpha = alpha[..., :pad_h, :pad_w]
        return F.interpolate(
            alpha, ori_shape, mode="bilinear", align_corners=False
        )

    alpha_pred_os8 = _to_original_size(pred["alpha_os8"])
    alpha_pred_os4 = _to_original_size(pred["alpha_os4"])
    alpha_pred_os1 = _to_original_size(pred["alpha_os1"])

    if maskguide:
        weight_os8 = get_unknown_func(post_mask, rand_width=os8_width, train_mode=False)
        post_mask[weight_os8 > 0] = alpha_pred_os8[weight_os8 > 0]
        alpha_pred = post_mask.clone().detach()
    else:
        alpha_pred = alpha_pred_os8.clone().detach()

    # Progressively replace the uncertain region with finer predictions.
    weight_os4 = get_unknown_func(alpha_pred, rand_width=os4_width, train_mode=False)
    alpha_pred[weight_os4 > 0] = alpha_pred_os4[weight_os4 > 0]

    weight_os1 = get_unknown_func(alpha_pred, rand_width=os1_width, train_mode=False)
    alpha_pred[weight_os1 > 0] = alpha_pred_os1[weight_os1 > 0]

    alpha_np = alpha_pred[0].cpu().numpy()
    # Round to the nearest level (astype alone truncates, biasing alpha
    # downwards) and clip so out-of-range floats cannot wrap around in uint8.
    alpha_np = np.clip(np.rint(alpha_np * 255), 0, 255).astype(np.uint8)
    # CHW -> HWC; make it contiguous so OpenCV can consume it directly.
    return np.ascontiguousarray(alpha_np.transpose(1, 2, 0))

if __name__ == "__main__":
    # Demo: predict the alpha matte for one image from a box prompt.
    CHECKPOINT = "checkpoints/mam_sam_vitb.pth"
    CONFIG_PATH = "config/MAM-ViTB-8gpu.toml"
    IMAGE_PATH = "./assets/demo.jpg"
    BBOX = [21, 109, 651, 1316]  # [xmin, ymin, xmax, ymax] in image pixels
    OUTPUT_PATH = "inference_demo.png"

    sample = prep_sample(IMAGE_PATH, BBOX)
    model = build_model(CONFIG_PATH, CHECKPOINT)
    alpha_pred = inference(model, sample)
    # Save as PNG: JPEG's lossy compression would corrupt the alpha values.
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(OUTPUT_PATH, alpha_pred):
        raise OSError(f"Could not write {OUTPUT_PATH!r}")

I adapted inference_benchmark.py into a quick script that runs on a single image with a bounding-box prompt.

inference_demo