chongzhou96 / EdgeSAM

Official PyTorch implementation of "EdgeSAM: Prompt-In-the-Loop Distillation for On-Device Deployment of SAM"
https://mmlab-ntu.com/project/edgesam/

How to create a batch_size=4 ONNX model? #32

Open yanzongs opened 2 months ago

yanzongs commented 2 months ago

When I change the code as follows:

import torch

def export_decoder_to_onnx(sam, args, batch_size=4):
    sam_decoder = SamCoreMLModel(
        model=sam,
        use_stability_score=args.use_stability_score
    )
    sam_decoder.eval()

    if args.gelu_approximate:
        for _, m in sam.named_modules():
            if isinstance(m, torch.nn.GELU):
                m.approximate = "tanh"

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size

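    # Dummy inputs used only to trace the export; the values are arbitrary.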
    image_embeddings = torch.randn(batch_size, embed_dim, *embed_size, dtype=torch.float)
    point_coords = torch.randint(low=0, high=1024, size=(batch_size, 5, 2), dtype=torch.float)
    point_labels = torch.randint(low=0, high=4, size=(batch_size, 5), dtype=torch.float)

    # Define the input names and output names
    input_names = ["image_embeddings", "point_coords", "point_labels"]
    output_names = ["scores", "masks"]

    # Export the decoder model to ONNX format
    onnx_decoder_filename = args.checkpoint.replace('.pth', '_decoder.onnx')
    torch.onnx.export(
        sam_decoder,
        (image_embeddings, point_coords, point_labels),
        onnx_decoder_filename,
        input_names=input_names,
        output_names=output_names,
        opset_version=13,  # Use an appropriate ONNX opset version
        dynamic_axes={
            "image_embeddings": {0: "batch_size"},
            "point_coords": {0: "batch_size", 1: "num_points"},
            "point_labels": {0: "batch_size", 1: "num_points"}
        },
        verbose=False
    )

    print(f"Exported ONNX decoder model to {onnx_decoder_filename}")

I get an error like:

  File "f:\ai_code\edgesam\edge_sam\modeling\transformer.py", line 165, in forward
    k = keys + key_pe
RuntimeError: The size of tensor a (16) must match the size of tensor b (4) at non-singleton dimension 0
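
If I read the sizes right, 16 = 4 × 4: the decoder seems to tile the image embeddings once per prompt batch, the way the original SAM mask decoder does. A sketch of that pattern, assuming edge_sam/modeling/mask_decoder.py mirrors the original predict_masks:

# Pattern from the original SAM MaskDecoder.predict_masks (assumption:
# EdgeSAM reuses it). tokens concatenates the output tokens with the sparse
# prompt embeddings, so tokens.shape[0] is the prompt batch size.
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)

With a batch-1 image embedding this is fine, but with batch 4 the first repeat_interleave inflates src to 4 × 4 = 16, while pos_src, built from the batch-1 dense positional encoding, only reaches 4, so keys + key_pe mismatches exactly as the traceback shows.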

What can I do to solve it? One workaround I am considering (an untested sketch, again assuming predict_masks matches the original SAM code) is to tile only when the batch dimensions actually differ:
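
# Untested sketch of a possible patch inside predict_masks(); an assumption,
# not an official fix. Tile the shared embedding only in the single-image
# case; batched embeddings already line up one-to-one with the prompt batch.
if image_embeddings.shape[0] == tokens.shape[0]:
    src = image_embeddings
else:
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)

Alternatively, exporting with batch_size=1 and looping over the batch at inference time would avoid touching the model code at all.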