vietanhdev / samexporter

Export Segment Anything Models to ONNX
https://pypi.org/project/samexporter/
MIT License
239 stars, 29 forks

How do I export the decoder for the MobileSAM (mobile) model? #13

Status: Open. CriusFission opened this issue 6 months ago.

CriusFission commented 6 months ago

I've exported a MobileSAM model with the export_encoder script, using the MobileSAM weights. How do I now export the decoder? I don't see a mobile model type in the export_decoder script. Also, should I pass the same mobile_sam.pt, or the vit_h.pth model, as the checkpoint for the decoder? Any help is appreciated.
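A hedged aside on the checkpoint question: MobileSAM and the original SAM share the prompt encoder and mask decoder but differ in the image encoder, so a checkpoint's state-dict keys show which loader it fits. The sketch below is a heuristic that assumes MobileSAM's TinyViT encoder stores its stages under image_encoder.layers.* while the ViT encoders (vit_h, vit_l, vit_b) use image_encoder.blocks.*. The helper guess_sam_variant is hypothetical and not part of samexporter.

```python
from typing import Mapping


def guess_sam_variant(state_dict: Mapping[str, object]) -> str:
    """Heuristically classify a SAM-family state dict by its image-encoder keys.

    Assumption (not an official check): MobileSAM's TinyViT encoder keeps its
    stages under "image_encoder.layers.*", while the original ViT encoders
    keep theirs under "image_encoder.blocks.*".
    """
    keys = list(state_dict)
    if any(k.startswith("image_encoder.layers.") for k in keys):
        return "mobile"
    if any(k.startswith("image_encoder.blocks.") for k in keys):
        return "vit"
    # Neither prefix found: possibly a decoder-only or differently named checkpoint.
    return "unknown"


# Usage with a real file (requires torch):
#   import torch
#   sd = torch.load("mobile_sam.pt", map_location="cpu")
#   print(guess_sam_variant(sd))

fake_mobile = {"image_encoder.layers.0.blocks.0.conv1.c.weight": None,
               "mask_decoder.iou_token.weight": None}
fake_vit = {"image_encoder.blocks.0.attn.qkv.weight": None,
            "image_encoder.pos_embed": None}
print(guess_sam_variant(fake_mobile))
print(guess_sam_variant(fake_vit))
```

Because the mobile branch in the reply below loads the state dict with strict=True into the MobileSAM architecture, a checkpoint that this heuristic classifies as "vit" would be expected to fail there with missing or unexpected keys.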

Tbolp commented 1 week ago
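```python
import torch
from segment_anything import sam_model_registry
# Assumed import path: copy whatever import export_encoder.py uses for setup_model.
from samexporter.mobile_encoder.setup_mobile_sam import setup_model


def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    use_preprocess: bool,
    opset: int,
    gelu_approximate: bool = False,
):
    print("Loading model...")
    if model_type == "mobile":
        # MobileSAM checkpoints are plain state dicts: build the model, then load the weights.
        state_dict = torch.load(checkpoint, map_location="cpu")
        sam = setup_model()
        sam.load_state_dict(state_dict, strict=True)
    else:
        # Original SAM variants (vit_h, vit_l, vit_b) are built and loaded by the registry.
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
    # ... rest of run_export in export_decoder.py unchanged ...
```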

I copied this code from export_encoder.py into export_decoder.py and ran export_decoder.py with --model-type mobile. The export finished without errors, but I have not checked whether the result is correct.
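One way to check the result, sketched under assumptions: MobileSAM keeps SAM's prompt encoder and mask decoder, and its image encoder produces the same 1x256x64x64 embedding, so the exported decoder should accept the same inputs as segment_anything's ONNX decoder. The hypothetical helpers below build dummy inputs with the names and shapes used in segment_anything's scripts/export_onnx_model.py and measure the largest gap between two outputs. That samexporter's export_decoder.py uses the same input names is an assumption; compare against the exported file's actual inputs first.

```python
import numpy as np


def make_decoder_inputs(num_points: int = 1, embed_dim: int = 256,
                        embed_size: int = 64, seed: int = 0) -> dict:
    """Dummy decoder inputs shaped like segment_anything's ONNX export inputs.

    Assumption: the exported decoder expects these input names and shapes.
    """
    rng = np.random.default_rng(seed)
    mask_size = 4 * embed_size  # low-res mask input is 4x the embedding size
    return {
        "image_embeddings": rng.standard_normal(
            (1, embed_dim, embed_size, embed_size)).astype(np.float32),
        "point_coords": rng.integers(0, 1024, size=(1, num_points, 2)).astype(np.float32),
        "point_labels": np.ones((1, num_points), dtype=np.float32),  # 1 = foreground
        "mask_input": np.zeros((1, 1, mask_size, mask_size), dtype=np.float32),
        "has_mask_input": np.zeros(1, dtype=np.float32),  # 0 = ignore mask_input
        "orig_im_size": np.array([1500, 2250], dtype=np.float32),
    }


def max_abs_diff(a, b) -> float:
    """Largest element-wise gap between two outputs (e.g. ONNX vs PyTorch masks)."""
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return float(np.max(np.abs(a - b)))


inputs = make_decoder_inputs(num_points=3)
for name, arr in inputs.items():
    print(name, arr.shape, arr.dtype)
```

With onnxruntime installed, session = onnxruntime.InferenceSession(path) loads the exported decoder (path being your output file), [i.name for i in session.get_inputs()] shows which input names it really expects, and session.run(None, inputs) runs it. Feeding the same inputs to a PyTorch forward pass of the same checkpoint and passing both mask outputs to max_abs_diff would give a concrete number rather than an unverified export.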