Open heorhiikalaichev opened 1 year ago
+1
Hi!
Thanks for your work! Could you share a script that processes a single input image together with its trimap?
I have the same question. Do you know how to do this?
from typing import Optional
import cv2
import numpy as np
import toml
import torch
from segment_anything.utils.transforms import ResizeLongestSide
from torch.nn import functional as F
import networks
import utils
from utils import CONFIG
def prep_sample(
    image_path: str,
    bbox: list,
    side: Optional[int] = 1024,
    device: str = "cuda",
) -> dict:
    """Load an image and its bounding box and prepare them for model inference.

    The image is read with OpenCV and converted BGR -> RGB. It is resized with
    SAM's ``ResizeLongestSide(side)``, normalized per channel with the mean/std
    constants below, and zero-padded on the bottom/right to ``side`` x ``side``.
    The bounding box is rescaled with the same transform.

    Args:
        image_path: Path of the image file to read.
        bbox: Box prompt ``[xmin, ymin, xmax, ymax]`` in original-image pixels.
        side: Size of the square model input (default 1024, the SAM setting).
        device: Torch device for the returned tensors (default ``"cuda"``,
            matching the previous hard-coded ``.cuda()`` behavior).

    Returns:
        dict with keys:
            "image": float tensor of shape (1, 3, side, side).
            "bbox": rescaled box tensor with a leading batch dimension added.
            "ori_shape": (height, width) of the image as loaded from disk.
            "pad_shape": (height, width) of the resized image *before*
                padding; ``inference`` uses it to crop the padding away.

    Raises:
        FileNotFoundError: if the image cannot be read from ``image_path``.
    """
    bbox = np.array(bbox)
    transform = ResizeLongestSide(side)

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error surfaces later as an opaque cv2 assertion.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_size = image.shape[:2]

    # HWC uint8 array -> CHW tensor on the target device.
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=device)
    image = image.permute(2, 0, 1).contiguous()

    # The box must go through the same resize as the image.
    bbox = transform.apply_boxes(bbox, original_size)
    bbox = torch.as_tensor(bbox, dtype=torch.float, device=device)

    pixel_mean = torch.tensor([123.675, 116.28, 103.53], device=device).view(3, 1, 1)
    pixel_std = torch.tensor([58.395, 57.12, 57.375], device=device).view(3, 1, 1)
    image = (image - pixel_mean) / pixel_std

    # Record the unpadded size, then pad bottom/right up to side x side.
    h, w = image.shape[-2:]
    pad_size = image.shape[-2:]
    padh = side - h
    padw = side - w
    image = F.pad(image, (0, padw, 0, padh))

    sample = {
        "image": image[None, ...],
        "bbox": bbox[None, ...],
        "ori_shape": original_size,
        "pad_shape": pad_size,
    }
    return sample
def build_model(config_path, checkpoint_path):
    """Build the matting model from a TOML config and load its M2M weights.

    The parsed TOML is passed to ``utils.load_config``. This presumably populates
    the shared ``CONFIG`` object read below; confirm in ``utils``. The generator
    is then created from ``CONFIG.model.arch.seg`` / ``CONFIG.model.arch.m2m``
    and moved to the GPU.

    Only ``model.m2m`` is loaded from the checkpoint. Its weights come from
    ``checkpoint["state_dict"]`` with the key prefix stripped by
    ``utils.remove_prefix_state_dict`` and are loaded with ``strict=True``.

    Args:
        config_path: Path to the TOML config file.
        checkpoint_path: Path to a checkpoint containing a ``"state_dict"`` key.

    Returns:
        The model, switched to eval mode.
    """
    with open(config_path, encoding="utf-8") as f:
        utils.load_config(toml.load(f))
    model = networks.get_generator_m2m(
        seg=CONFIG.model.arch.seg, m2m=CONFIG.model.arch.m2m
    )
    model.cuda()
    # Deserialize onto the CPU. Without map_location, tensors are restored to
    # the device they were saved from, which fails on machines lacking that GPU
    # and temporarily duplicates the weights in GPU memory. load_state_dict
    # copies the values into the model's (CUDA) parameters afterwards.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoint files.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.m2m.load_state_dict(
        utils.remove_prefix_state_dict(checkpoint["state_dict"]), strict=True
    )
    return model.eval()
def _crop_and_resize(alpha, valid_shape, out_shape):
    """Crop padding from an (N, C, H, W) alpha tensor, then resize it bilinearly to out_shape."""
    alpha = alpha[..., : valid_shape[0], : valid_shape[1]]
    return F.interpolate(alpha, out_shape, mode="bilinear", align_corners=False)


@torch.no_grad()
def inference(
    model: networks.generator_m2m.sam_m2m,
    image_dict: dict,
    os8_width: Optional[int] = 10,
    os4_width: Optional[int] = 20,
    os1_width: Optional[int] = 10,
    twoside: Optional[bool] = False,
    maskguide: Optional[bool] = False,
):
    """Run the model on one prepared sample and return the alpha matte as uint8.

    ``model.forward_inference(image_dict)`` yields alpha predictions at three
    output strides ("alpha_os8", "alpha_os4", "alpha_os1"). Each one is cropped
    to ``image_dict["pad_shape"]`` to drop the padding, then resized back to
    ``image_dict["ori_shape"]``. They are fused coarse-to-fine:

    1. Start from the os8 prediction. If ``maskguide`` is set, start instead
       from the model's ``post_mask``, with os8 values pasted into its unknown
       band.
    2. Replace pixels in the unknown band of the current estimate with os4.
    3. Repeat step 2 with os1.

    Args:
        model: Model exposing ``forward_inference(image_dict)``.
        image_dict: Sample dict as produced by ``prep_sample``.
        os8_width, os4_width, os1_width: Band widths passed as ``rand_width``
            to the unknown-region helper (with ``train_mode=False``;
            presumably a fixed rather than random width, so confirm in utils).
        twoside: Use ``utils.get_unknown_tensor_from_mask`` instead of the
            one-sided variant to compute the unknown band.
        maskguide: Seed the fusion with ``post_mask`` instead of the os8 output.

    Returns:
        numpy uint8 array of shape (H, W, C) for the first item of the batch,
        with values scaled to 0..255.
    """
    get_unknown_func = (
        utils.get_unknown_tensor_from_mask
        if twoside
        else utils.get_unknown_tensor_from_mask_oneside
    )

    _, pred, post_mask = model.forward_inference(image_dict)

    valid_shape = (image_dict["pad_shape"][0], image_dict["pad_shape"][1])
    ori_shape = image_dict["ori_shape"]
    alpha_pred_os8 = _crop_and_resize(pred["alpha_os8"], valid_shape, ori_shape)
    alpha_pred_os4 = _crop_and_resize(pred["alpha_os4"], valid_shape, ori_shape)
    alpha_pred_os1 = _crop_and_resize(pred["alpha_os1"], valid_shape, ori_shape)

    if maskguide:
        weight_os8 = get_unknown_func(post_mask, rand_width=os8_width, train_mode=False)
        post_mask[weight_os8 > 0] = alpha_pred_os8[weight_os8 > 0]
        alpha_pred = post_mask.clone().detach()
    else:
        alpha_pred = alpha_pred_os8.clone().detach()

    weight_os4 = get_unknown_func(alpha_pred, rand_width=os4_width, train_mode=False)
    alpha_pred[weight_os4 > 0] = alpha_pred_os4[weight_os4 > 0]
    weight_os1 = get_unknown_func(alpha_pred, rand_width=os1_width, train_mode=False)
    alpha_pred[weight_os1 > 0] = alpha_pred_os1[weight_os1 > 0]

    # For a hard binary mask instead of a soft matte, threshold alpha_pred here
    # (e.g. alpha_pred > 0.5).
    alpha = alpha_pred[0].cpu().numpy() * 255
    # Clip before the uint8 cast: values outside [0, 255] would otherwise wrap
    # around (e.g. -1 -> 255) and corrupt the matte.
    alpha = np.clip(alpha, 0, 255)
    return alpha.transpose(1, 2, 0).astype("uint8")
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded demo; override on the CLI to
    # run on any single image with a bounding-box prompt.
    parser = argparse.ArgumentParser(
        description="Run matting on a single image with a bounding-box prompt."
    )
    parser.add_argument("--checkpoint", default="checkpoints/mam_sam_vitb.pth")
    parser.add_argument("--config", default="config/MAM-ViTB-8gpu.toml")
    parser.add_argument("--image", default="./assets/demo.jpg")
    parser.add_argument(
        "--bbox",
        type=int,
        nargs=4,
        default=[21, 109, 651, 1316],
        metavar=("XMIN", "YMIN", "XMAX", "YMAX"),
        help="Box prompt in original-image pixel coordinates.",
    )
    parser.add_argument(
        "--output",
        default="inference_demo.png",
        help="Output path; PNG is recommended because it is lossless.",
    )
    args = parser.parse_args()

    sample = prep_sample(args.image, args.bbox)
    model = build_model(args.config, args.checkpoint)
    alpha_pred = inference(model, sample)
    # Write the matte losslessly; JPEG compression adds artifacts to alpha values.
    # cv2.imwrite reports failure via its return value instead of raising.
    if not cv2.imwrite(args.output, alpha_pred):
        raise SystemExit(f"Failed to write alpha matte to {args.output}")
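I made a quick script, adapted from inference_benchmark.py, that runs inference on a single image using a bounding-box prompt. Note that it takes a bounding box, not a trimap, as the guidance input.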
Hi!
Thanks for your work! Could you share a script that processes a single input image together with its trimap?