JJGO / UniverSeg

UniverSeg: Universal Medical Image Segmentation
Apache License 2.0

Torch to ONNX conversion is very slow #16

Open matt-kh opened 1 year ago

matt-kh commented 1 year ago

When I convert the PyTorch model to ONNX, torch.onnx.export() runs for more than 8 hours without finishing and without raising any exception.

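import torch
from universeg import UniverSeg  # imports were omitted in the original post; this import path is assumed

device = 'cuda' if torch.cuda.is_available() else 'cpu'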
# Model Initialization
encoder_blocks = [64,] * 4
weights_url = "https://github.com/JJGO/UniverSeg/releases/download/weights/universeg_v1_nf64_ss64_STA.pt"
model = UniverSeg(encoder_blocks=encoder_blocks)
state_dict = torch.hub.load_state_dict_from_url(weights_url)
model.load_state_dict(state_dict)
_ = model.to(device)
_ = model.eval()

# Dummy inputs
torch.manual_seed(42)
target_image = torch.randn(1, 1, 256, 256, device=device)
support_images = torch.randn(1, 64, 1, 256, 256, device=device)
support_labels = torch.randn(1, 64, 1, 256, 256, device=device)

# ONNX conversion
export_path = "universeg.onnx"  # placeholder; export_path was not defined in the original post
input_names = ["target_image", "support_images", "support_labels"]
output_names = ["logits"]
torch.onnx.export(
    model=model,
    args=(target_image, support_images, support_labels),
    f=export_path,
    input_names=input_names,
    output_names=output_names, 
    export_params=True,
    do_constant_folding=True,
    dynamic_axes={
        "target_image":{0: "batch", 1: "channel", 2: "height", 3: "width"},
        "support_images": {0: "batch", 1:"support_size", 2: "channel", 3: "height", 4: "width"},
        "support_labels": {0: "batch", 1:"support_size", 3: "height", 4: "width"},
        "logits" : {0: "batch", 2: "height", 3: "width"}
    },
    verbose=True,
    opset_version=16,
)

The following warning from the einops package (.../einops/einops.py) is emitted during conversion:

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}

However, the code raises no exceptions. I am not sure whether the tracer warnings caused by einops are related to the seemingly indefinite run of torch.onnx.export().
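
To show what the warning describes, here is a minimal sketch using a toy function. It is not UniverSeg or einops code, and `shift` is a hypothetical name chosen for illustration. It only shows that a Python-level branch on a tensor value is fixed at trace time; it does not by itself explain the long export time.

```python
import warnings

import torch


def shift(x):
    # Python-level branch on a tensor value: converting the comparison to a
    # Python bool is what triggers the TracerWarning.
    if x.sum() > 0:
        return x + 10
    return x - 10


with warnings.catch_warnings():
    # Silence the TracerWarning here so the output stays readable.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(shift, torch.ones(3))  # records the `x + 10` branch

neg = -torch.ones(3)
eager_out = shift(neg)    # eager mode takes the `x - 10` branch
traced_out = traced(neg)  # the trace replays the `x + 10` branch regardless of input

print(eager_out)
print(traced_out)
```

The eager and traced results differ for the negative input because the branch decision was stored as a constant when the function was traced. The einops warning above concerns axis-length comparisons, so any effect on the exported graph would have to be checked separately.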

I would appreciate any help with this issue. Thank you.