facebookresearch / segment-anything

The repository provides code for running inference with the SegmentAnything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0
47.87k stars 5.66k forks source link

Incorret result scaling #753

Open neko-para opened 6 months ago

neko-para commented 6 months ago

Currently, my program can perfectly work on the demo pictures(e.g. images/truck.jpg). But when I switched to my own pngs, the result seems to be scaled incorrectly.

For instance, the shown size of image below is 432x770, while the result mask seems to be only in 245x770. The image been processed

432*432=188624
245*770=188650

126200fa37b61eea6d75a2465ff99997

import sys
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel

import json
import warnings

import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

def show_mask(mask, ax):
    color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))   

# checkpoint = "sam_vit_h_4b8939.pth"
checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

sam = sam_model_registry[model_type](checkpoint=checkpoint)

# onnx_model_path = None  # Set to use an already exported model, then skip to the next section.

onnx_model_path = "sam_onnx_example.onnx"

onnx_model = SamOnnxModel(sam, return_single_mask=True)

dynamic_axes = {
    "point_coords": {1: "num_points"},
    "point_labels": {1: "num_points"},
}

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
    warnings.filterwarnings("ignore", category=UserWarning)
    with open(onnx_model_path, "wb") as f:
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=17,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

'''
onnx_model_quantized_path = "sam_onnx_quantized_example.onnx"
quantize_dynamic(
    model_input=onnx_model_path,
    model_output=onnx_model_quantized_path,
    optimize_model=True,
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)
onnx_model_path = onnx_model_quantized_path
'''

ort_session = onnxruntime.InferenceSession(onnx_model_path)

# sam.to(device='cuda')
predictor = SamPredictor(sam)

while True:
  print('-file')
  image_path = input()
  if image_path == '-q':
    break
  image = cv2.imread(image_path)
  print(f"size: {image.shape}", file=sys.stderr)

  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

  predictor.set_image(image)

  image_embedding = predictor.get_image_embedding().cpu().numpy()

  # np.save('image_embedding.npy', image_embedding)

  print("-loaded")

  while True:
    point_str = input()
    if point_str == '-q':
      break
    point_info = json.loads(point_str)

    input_point = np.array(point_info)
    input_label = np.array([1])

    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

    onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    ort_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
    }

    masks, _, low_res_logits = ort_session.run(None, ort_inputs)
    masks = masks > predictor.model.mask_threshold

    h, w = masks.shape[-2:]
    print(f"sizp: {h} {w}", file=sys.stderr)

    print(masks.shape, file=sys.stderr)
    y_coords, x_coords = np.where(masks.reshape(h, w))

    if len(x_coords) == 0:
      print(f'-result [0, 0, 0, 0]')
    else:
      x_min, x_max = x_coords.min(), x_coords.max()
      y_min, y_max = y_coords.min(), y_coords.max()

      print(f'-result [{x_min}, {y_min}, {x_max}, {y_max}]')

    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(masks, plt.gca())
    show_points(input_point, input_label, plt.gca())
    show_box([x_min, y_min, x_max, y_max], plt.gca())
    plt.axis('off')
    plt.show()
neko-para commented 6 months ago

stderr of the processing.

size: (1280, 720, 3)
{ pts: [ [ 707.1875, 55.8046875 ] ] }
sizp: 1280 720
(1, 1, 1280, 720)
heyoeyo commented 6 months ago

The SAM model internally scales the input image to fit inside a 1024x1024 resolution and uses padding to fill out the missing space, which would be 'to the right' of your image in this case (to fill in the narrow side of the image). The mask decoder is supposed to remove this padding, which requires knowing the size of the original image (through the orig_im_size input most likely).

In this case, it looks like the cropping has been flipped: the mask looks cropped at the bottom (judging by the misalignment of the mask with the search bar part of the image) instead of removing the padding on the right (which is why it looks like there's a gap on the right). The width of the displayed mask result also seems to confirm this. The original image would have taken up 56.25% (720/1280) of the width of the internal padded image and 56.25% of your displayed image (432px) is 243px, very close to the maximum expected width of the resulting mask if the padding isn't removed.

As for fixing it, I'm not very familiar with the onnx side of things, but maybe the orig_im_size input needs to be made dynamic? Otherwise I would try flipping the order of the height & width values given as the orig_im_size input.

neko-para commented 6 months ago

@heyoeyo After switching orig_im_size, it seems that sam truncate the longer side with the same ratio. I've also tried pass [h, h] or [w, w], which doesn't work either.

switch width & height

4aae4445900c905172d26de9d4db5d64

pass height & height

932a9847979bce02374e9ed5b039567f

pass width & width

28770cb6439aff069e2d223b9965c855

heyoeyo commented 6 months ago

Weird! It definitely seems like there's something wrong with the cropping and/or scaling of the mask result to remove the input padding. Swapping the width & height at least seems to fix the removal of the right-side padding (judging from the fact that the mask is horizontally aligned correctly), but it's clearly messing up the scaling still. Though looking at the onnx version of the model, the post-processing code looks ok to me...

As a sanity check, it might be worth manually handling the scaling/padding removal (using the low_res_logits output of the model), just to be sure that the correct transformations are being done. The basic steps are:

  1. Resize the low_res_logits to 1024x1024
  2. Crop the top-left portion of the 1024x1024 mask according to how big your input image would've been when scaled to this size. The predictor should have stored this as predictor.input_size, in this case I think it should be: 1024x576
  3. Scale the cropped mask back to the original image size (1280x720) You can try visualizing the result after each step to make sure the mask makes sense.

Assuming the masks come out as an np.array, I think something like this should work:

# Show low-res mask result after upscaling
result_uint8 = np.uint8((low_res_logits.squeeze() > 0) * 255)
scaled_uint8 = cv2.resize(result_uint8, dsize=(1024,1024))
cv2.imshow("Scaled low-res result", result_uint8)
cv2.waitKey(250)

# Show result after removing padding
cropped_uint8 = scaled_uint8[0:1024, 0:576]
cv2.imshow("Cropped result", cropped_uint8)
cv2.waitKey(250)

# Show final mask scaled back to original size
final_uint8 = cv2.resize(cropped_uint8, dsize=(720,1280))
cv2.imshow("Final result", final_uint8)

# Show windows until a keypress occurs, then close them all
cv2.waitKey(0)
cv2.destroyAllWindows()

This should pop-up a bunch of windows to show the intermediate results. The mask will look worse, since the thresholding (>0 check) is happening before scaling, but it should at least give a sense of whether the mask is being cropped/scaled properly, or if something is wrong with the sizings.