facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

ONNX model produces worse result than Pytorch counterpart #714

Open james-imi opened 8 months ago

james-imi commented 8 months ago

So I have the following PyTorch prediction code for a fine-tuned model, using only bounding boxes as prompts.

PyTorch Prediction

import numpy as np
import matplotlib.pyplot as plt
from segment_anything import SamPredictor

# Box prompt in (x1, y1, x2, y2) pixel coordinates.
bbox = np.array([1055, 412, 1286, 991])

# `sam` is the fine-tuned model, `image` the RGB input image.
predictor = SamPredictor(sam)
predictor.set_image(image)
masks, scores, logits = predictor.predict(
    box=bbox,
    multimask_output=False,
)
plt.imshow(masks[0])

and get a mask like this (which is correct): [attached image: correct mask]
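For reference, `sam` and `image` above are set up roughly like this (a sketch; the checkpoint path, model type, and image path are placeholders):

import cv2
from segment_anything import sam_model_registry

# Hypothetical checkpoint path and model type -- substitute the
# fine-tuned weights actually used.
sam = sam_model_registry["vit_h"](checkpoint="sam_finetuned.pth")
sam.to(device="cuda")

# SamPredictor expects an RGB image.
image = cv2.imread("example.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)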

ONNX Prediction

import torch
from segment_anything.utils.onnx import SamOnnxModel

# Wrap the model for export; return_single_mask matches multimask_output=False.
onnx_model = SamOnnxModel(sam, return_single_mask=True)

# Dummy inputs that fix the export signature.
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    "image_embeddings": torch.randn(1, sam.prompt_encoder.embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *([4 * x for x in embed_size]), dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}

model_output_path = "sam_finetuned.onnx"  # hypothetical path for the exported model
with open(model_output_path, "wb") as f:
    torch.onnx.export(
        onnx_model,
        tuple(dummy_inputs.values()),
        f,
        export_params=True,
        verbose=False,
        opset_version=15,
        do_constant_folding=True,
        input_names=list(dummy_inputs.keys()),
        output_names=["masks", "iou_predictions", "low_res_masks"],
        dynamic_axes={
            "point_coords": {1: "num_points"}, "point_labels": {1: "num_points"},
        },
    )  
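Before running inference, the exported graph can be sanity-checked for structural problems (a small sketch using the onnx package; `model_output_path` is the path used above):

import onnx

# Raises if the exported graph is structurally invalid.
onnx.checker.check_model(onnx.load(model_output_path))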

import onnxruntime

# Start an ONNX Runtime inference session
ort_session = onnxruntime.InferenceSession(model_output_path)

# Image embedding from the predictor (the image was set above).
image_embedding = predictor.get_image_embedding().cpu().numpy()

# Encode the bounding box as two points: top-left corner (label 2)
# and bottom-right corner (label 3).
onnx_box_coords = bbox.reshape(2, 2)
onnx_box_labels = np.array([2, 3])

onnx_coord = onnx_box_coords[None, :, :]
onnx_label = onnx_box_labels[None, :].astype(np.float32)
onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

# No mask input for the first prediction.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.array([0], dtype=np.float32)

ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}

# Predict
masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold
plt.imshow(masks[0][0])

and get this wrong one: [attached image: incorrect mask]

Any idea what could be causing this?
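One apples-to-apples check worth trying here (a minimal sketch, reusing `predictor`, `bbox`, and `ort_inputs` from above): both pipelines also return 256x256 low-res mask logits, which should roughly agree for the same prompt.

# PyTorch low-res logits (third return value) for the same box prompt.
_, _, pt_low_res = predictor.predict(box=bbox, multimask_output=False)

# ONNX low-res logits (third output of the exported model).
_, _, onnx_low_res = ort_session.run(None, ort_inputs)

# If the two pipelines agree, these should be close element-wise.
print(np.abs(pt_low_res - onnx_low_res[0]).max())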

zhangzeyang000 commented 5 months ago

try this again, maybe you can get a proper mask:

mask = masks[0][0]
mask = (mask > 0).astype('uint8') * 255
plt.imshow(mask)
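If the mask still looks wrong after that, plotting the raw ONNX logits before any thresholding may help narrow things down (a sketch reusing `ort_inputs` from above):

raw_masks, _, _ = ort_session.run(None, ort_inputs)

# Unthresholded logits: structure here that disappears after thresholding
# points at the threshold; a flat map points at the box prompt or embedding.
plt.imshow(raw_masks[0][0])
plt.colorbar()
plt.show()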