facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0
47.18k stars 5.59k forks

Cannot export onnx #71

Open YaoJiawei329 opened 1 year ago

YaoJiawei329 commented 1 year ago

An extraordinary work! I tried to export the model to ONNX, but an error occurred. With opset=11, 12, or 13, the error message is: RuntimeError: Exporting the operator repeat_interleave to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub. With opset=14, 15, 16, or 17, the error message is: ValueError: Unsupported ONNX opset version: 14

Environment: Windows 11, 12700H, RTX 3070 Ti laptop, PyTorch 1.8.2, ONNX 1.12
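Because every failure in this thread depends on the installed PyTorch/ONNX versions, it helps to report them exactly. A minimal sketch using only the standard library (`importlib.metadata`); the package list is an assumption, so adjust it to your setup:

```python
from importlib import metadata


def installed_versions(packages=("torch", "onnx", "onnxruntime", "onnxruntime-gpu")):
    """Return {package: version string, or None if the package is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed in the current environment
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```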

AllenZYJ commented 1 year ago

I get the same error: ValueError: Unsupported ONNX opset version: 17

YaoJiawei329 commented 1 year ago

I created a new conda env with pytorch=1.12 and opset=1, which solved the problem.

HighPoint commented 1 year ago

Try PyTorch 2.0. The requirements are likely PyTorch 2.0 and opset version 17.

julinfn commented 1 year ago

I got the same problem. Changing the opset from 17 to 12 gets past it, but then I got a new error: 'torch._C.Value' object is not iterable. Is this a problem with the PyTorch version?

torch:1.10 py:3.9.2

wavelet2008 commented 1 year ago

    Exporing onnx model to out/dd.onnx...
    Traceback (most recent call last):
      File "/home/ubuntu/seg/segment-anything/scripts/export_onnx_model.py", line 180, in <module>
        run_export(
      File "/home/ubuntu/seg/segment-anything/scripts/export_onnx_model.py", line 154, in run_export
        torch.onnx.export(
      File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/__init__.py", line 275, in export
        return utils.export(model, args, f, export_params, verbose, training,
      File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/utils.py", line 88, in export
        _export(model, args, f, export_params, verbose, training, input_names, output_names,
      File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/utils.py", line 689, in _export
        _model_to_graph(model, args, verbose, input_names,
      File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/utils.py", line 463, in _model_to_graph
        graph = _optimize_graph(graph, operator_export_type,
      File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/utils.py", line 200, in _optimize_graph
        graph = torch._C._jit_pass_onnx(graph, operator_export_type)
      File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/__init__.py", line 313, in _run_symbolic_function
        return utils._run_symbolic_function(*args, **kwargs)
      File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/utils.py", line 994, in _run_symbolic_function
        return symbolic_fn(g, *inputs, **attrs)
      File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/symbolic_opset11.py", line 922, in repeat_interleave
        return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)
      File "/home/ubuntu/anaconda3/envs/face19/lib/python3.9/site-packages/torch/onnx/symbolic_opset9.py", line 2064, in repeat_interleave
        for idx, r_split in enumerate(r_splits):
    TypeError: 'torch._C.Value' object is not iterable
    (Occurred when translating repeat_interleave).

MolianWH commented 1 year ago

It seems that it only works with PyTorch 2.0 (CUDA 11.7). I updated CUDA and PyTorch (other versions failed) and now it works. By the way, I modified the onnxruntime.InferenceSession parameters at line 168:

    ort_session = onnxruntime.InferenceSession(output, providers=['CUDAExecutionProvider'])

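A slightly more defensive variant of that change: a sketch that requests CUDA only when the installed onnxruntime build offers it and otherwise falls back to CPU. `pick_providers` is a hypothetical helper; `onnxruntime.get_available_providers()` and the `providers=` argument of `InferenceSession` are the real onnxruntime APIs.

```python
def pick_providers(available, preferred=("CUDAExecutionProvider", "CPUExecutionProvider")):
    """Keep the preferred execution providers that this onnxruntime build offers, in order."""
    chosen = [p for p in preferred if p in available]
    if not chosen:
        raise RuntimeError(f"None of {preferred} are available; got {available}")
    return chosen


if __name__ == "__main__":
    try:
        import onnxruntime  # onnxruntime or onnxruntime-gpu must be installed
    except ImportError:
        onnxruntime = None
    if onnxruntime is not None:
        # e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'] on a GPU build
        print(pick_providers(onnxruntime.get_available_providers()))
```

The result can then be passed as `onnxruntime.InferenceSession(output, providers=pick_providers(onnxruntime.get_available_providers()))`.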
AllenZYJ commented 1 year ago

It no longer throws any errors after I updated PyTorch to 2.0 and ONNX to 1.13.1.

UNeedCryDear commented 1 year ago

Hey guys! In this PR: https://github.com/facebookresearch/segment-anything/pull/210, after changing torch.repeat_interleave() to Tensor.expand(), I successfully exported the model under torch 1.8.2 + opset=12. But I'm not sure how this will affect performance.
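On the performance question, a quick eager-mode PyTorch check (the shapes are illustrative, assuming a batch-1 image embedding like the one the ONNX model takes): `repeat_interleave` materializes every copy, while `expand` returns a stride-0 view of the same storage. This says nothing about ONNX Runtime speed, which would have to be measured separately.

```python
import torch

# Illustrative SAM-like shapes: one image embedding (batch 1) and 4 masks
image_embeddings = torch.randn(1, 256, 64, 64)
num_masks = 4

copied = torch.repeat_interleave(image_embeddings, num_masks, dim=0)
viewed = image_embeddings.expand(num_masks, *image_embeddings.shape[1:])

# repeat_interleave allocates a new tensor holding num_masks full copies
print(copied.data_ptr() == image_embeddings.data_ptr())  # False
# expand returns a view: stride 0 along dim 0, same underlying storage
print(viewed.stride()[0], viewed.data_ptr() == image_embeddings.data_ptr())  # 0 True
```

In the mask decoder the very next line (`src = src + dense_prompt_embeddings`) produces a new tensor anyway, so the values downstream are unchanged.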

lauraset commented 1 year ago

@UNeedCryDear It does work.

InterstellarFang commented 1 year ago

@UNeedCryDear Hello! Which Python file is this function (torch.repeat_interleave()) in? Can you tell me the location of this file? Thanks a lot!

UNeedCryDear commented 1 year ago

https://github.com/facebookresearch/segment-anything/pull/210/files

InterstellarFang commented 1 year ago

@UNeedCryDear Thank you! I made changes based on the code you provided (https://github.com/facebookresearch/segment-anything/pull/210/files) and added the four lines of code, but it still reports an error (RuntimeError: Exporting the operator repeat_interleave to ONNX opset version 12 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.) under torch 1.8.1 + opset=12.

UNeedCryDear commented 1 year ago

[screenshot]

The code with a pink background has been replaced and you need to remove it.

InterstellarFang commented 1 year ago

@UNeedCryDear I commented out (with #) the two lines with a pink background, but it still reports an error (RuntimeError: Exporting the operator repeat_interleave to ONNX opset version 12 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.) under torch 1.8.1 + opset=12. Is this a problem with the torch version? I would also like to confirm whether the default ONNX opset value needs to be changed in these two files (notebooks/onnx_model_example.ipynb, scripts/export_onnx_model.py).

UNeedCryDear commented 1 year ago

Show me the code you modified.

InterstellarFang commented 1 year ago

    # Expand per-image data in batch direction to be per-mask
    # src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    src_shape = (tokens.shape[0],*image_embeddings.shape[1:])
    src = image_embeddings.expand(src_shape)
    src = src + dense_prompt_embeddings
    # pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    pos_src_shape = (tokens.shape[0],*image_pe.shape[1:])
    pos_src = image_pe.expand(pos_src_shape)
    b, c, h, w = src.shape

@UNeedCryDear
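The edit above is only equivalent to the original because the image embedding and positional encoding have a batch dimension of 1 (`expand` can only broadcast size-1 dimensions). A self-contained sketch that wraps the same lines in a hypothetical `expand_per_mask` helper (with `num_masks` standing in for `tokens.shape[0]`), guards that assumption, and compares the result against `repeat_interleave`:

```python
import torch


def expand_per_mask(image_embeddings, image_pe, dense_prompt_embeddings, num_masks):
    """Tile batch-1 image tensors to one copy per mask, mirroring the PR #210 edit."""
    # expand() only broadcasts size-1 dims; repeat_interleave would also handle B > 1
    if image_embeddings.shape[0] != 1 or image_pe.shape[0] != 1:
        raise ValueError("expand-based tiling assumes a batch dimension of 1")
    src = image_embeddings.expand(num_masks, *image_embeddings.shape[1:])
    src = src + dense_prompt_embeddings  # the addition materializes a new tensor
    pos_src = image_pe.expand(num_masks, *image_pe.shape[1:])
    return src, pos_src


# Compare against the original repeat_interleave formulation on small random data
emb = torch.randn(1, 8, 4, 4)
pe = torch.randn(1, 8, 4, 4)
dense = torch.randn(3, 8, 4, 4)
src, pos_src = expand_per_mask(emb, pe, dense, num_masks=3)
ref_src = torch.repeat_interleave(emb, 3, dim=0) + dense
ref_pos = torch.repeat_interleave(pe, 3, dim=0)
print(torch.equal(src, ref_src), torch.equal(pos_src, ref_pos))  # True True
```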

UNeedCryDear commented 1 year ago

[screenshot]

I searched for repeat_interleave in the project, and this is the only place it is called. Your code is correct, unless it differs from the function that actually gets called. So, have you saved your modifications?
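A rough sketch of that project-wide search in Python, for readers without grep, assuming the repository is cloned locally. `find_calls` is a hypothetical helper; it only checks `.py` files and treats lines starting with `#` as commented out.

```python
import pathlib


def find_calls(root, needle="repeat_interleave", suffixes=(".py",)):
    """Yield (path, line_number, stripped_line) for each non-comment line containing needle."""
    for path in sorted(pathlib.Path(root).rglob("*")):
        if not path.is_file() or path.suffix not in suffixes:
            continue
        text = path.read_text(encoding="utf-8", errors="replace")
        for lineno, line in enumerate(text.splitlines(), start=1):
            stripped = line.strip()
            # Lines starting with "#" were commented out (like the replaced calls in PR #210)
            if needle in stripped and not stripped.startswith("#"):
                yield path, lineno, stripped
```

For example, `for hit in find_calls("segment-anything"): print(hit)` lists every call site that is still active, so any remaining hit is worth checking.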

InterstellarFang commented 1 year ago

@UNeedCryDear Yes, I saved my modifications. Should I try other ONNX opset values besides 12?

InterstellarFang commented 1 year ago

@UNeedCryDear Hello! I don't know how to fix the error. Can you give me some advice? Thanks a lot!

UNeedCryDear commented 1 year ago

Judging from the error, the modification was not applied successfully. You can search the whole project as I did to find where the changes were not made correctly. Also, if you are using Jupyter Notebook or Colab, the files you modified may not be the files that actually run. The correct approach is to git clone the code and modify it locally instead of installing it with pip. [screenshot]
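To check whether Python is importing your edited clone or a pip-installed copy, a minimal sketch using the standard `importlib.util.find_spec`; `module_origin` is a hypothetical helper name:

```python
import importlib.util


def module_origin(name):
    """Return the file Python would import for a top-level module, or None if not found."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    # Prints something like .../segment_anything/__init__.py, or None if not importable
    print(module_origin("segment_anything"))
```

If the printed path points into `site-packages` instead of your clone, your edits are not what runs; installing the clone in editable mode (`pip install -e .` from the repository root) makes imports resolve to it.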

Finally, here is the .py file I use to export, based on the modifications in PR #210. If you still cannot export, I suggest you try switching to PyTorch 2.0. Good luck!

    python export.py --checkpoint path/to/checkpoint --type vit_b --opset 12


    import argparse
    import warnings

    import onnx
    import torch

    from segment_anything import sam_model_registry, SamPredictor
    from segment_anything.utils.onnx import SamOnnxModel


    def export_onnx(sam_checkpoint="sam_vit_b_01ec64.pth", model_type="vit_b", opset=12,
                    onnx_model_path="sam_onnx_example_maskdeocde.onnx"):
        # Build SAM and wrap its prompt encoder + mask decoder for export
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        onnx_model = SamOnnxModel(sam, return_single_mask=True)
        # The number of prompt points can vary at inference time
        dynamic_axes = {
            "point_coords": {1: "num_points"},
            "point_labels": {1: "num_points"},
        }

        embed_dim = sam.prompt_encoder.embed_dim
        embed_size = sam.prompt_encoder.image_embedding_size
        mask_input_size = [4 * x for x in embed_size]
        img_size = sam.image_encoder.img_size

        # img and dynamic_shape are not used by this export
        img = torch.randn(1, 3, img_size, img_size, dtype=torch.float)
        dynamic_shape = {'images': {0: 'batch', 2: 'height', 3: 'width'}}
        # Dummy inputs used for tracing; their order must match SamOnnxModel.forward
        dummy_inputs = {
            "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
            "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
            "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
            "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
            "has_mask_input": torch.tensor([1], dtype=torch.float),
            "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
        }
        output_names = ["masks", "iou_predictions", "low_res_masks"]

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
            warnings.filterwarnings("ignore", category=UserWarning)
            with open(onnx_model_path, "wb") as f:
                torch.onnx.export(
                    onnx_model,
                    tuple(dummy_inputs.values()),
                    f,
                    export_params=True,
                    verbose=False,
                    opset_version=opset,
                    do_constant_folding=True,
                    input_names=list(dummy_inputs.keys()),
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                )
        model_onnx = onnx.load(onnx_model_path)  # load the exported model
        onnx.checker.check_model(model_onnx)  # validate the ONNX graph
        print("Done!")


    def parse_opt():
        parser = argparse.ArgumentParser()
        parser.add_argument("--checkpoint", type=str, default="./model_weights/sam_vit_b_01ec64.pth",
                            help="The path to the SAM model checkpoint.")
        parser.add_argument("--type", type=str, default="vit_b",
                            help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.")
        parser.add_argument("--opset", type=int, default=12, help="The ONNX opset version to use.")
        parser.add_argument("--output", type=str, default="sam_onnx_example_maskdeocde.onnx",
                            help="The filename to save the ONNX model to.")
        opt = parser.parse_args()
        return opt


    if __name__ == '__main__':
        opt = parse_opt()
        export_onnx(opt.checkpoint, opt.type, opt.opset, opt.output)

captainIT commented 4 months ago

Regarding the PR #210 workaround (changing torch.repeat_interleave() to expand()): it is true that the ONNX model can be exported successfully, but the web demo does not work properly with it.

UNeedCryDear commented 4 months ago

I'm sorry, I can't help you; I'm not familiar with web development at all. If you need to use the web demo, it's best to use the original code.