MIC-DKFZ / nnUNet

Apache License 2.0
5.9k stars, 1.76k forks

ONNX export #2407

Open: ogencoglu opened this issue 3 months ago

ogencoglu commented 3 months ago

Do you support ONNX export and is it tested? Any examples of it would be appreciated.

ZxnSnowy commented 3 months ago

I have the same question. Could you give an example? Thank you very much.

rubencardenes commented 3 weeks ago

Hi, conversion to ONNX can be done like this:

```python
from typing import Union

import torch
import nnunetv2
from batchgenerators.utilities.file_and_folder_operations import join, load_json
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager


def convert_to_ONNX(
    model_dir: str,
    onnx_model_path: str,
    batch_size: int = 1,
    fold: Union[int, str] = 1,
    checkpoint_name: str = "checkpoint_best.pth",
):
    # model_dir is the trained model folder containing dataset.json, plans.json and fold_X/
    dataset_json = load_json(join(model_dir, "dataset.json"))
    plans = load_json(join(model_dir, "plans.json"))
    plans_manager = PlansManager(plans)

    # fold can be 0-4 or 'all'; the checkpoint lives in model_dir/fold_<fold>/
    checkpoint = torch.load(
        join(model_dir, f"fold_{fold}", checkpoint_name),
        map_location=torch.device("cpu"),
        weights_only=False,  # checkpoint stores more than tensors; PyTorch >= 2.6 defaults to True
    )
    trainer_name = checkpoint["trainer_name"]
    configuration_name = checkpoint["init_args"]["configuration"]
    configuration_manager = plans_manager.get_configuration(configuration_name)

    # restore the network the same way the trainer built it
    num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
    trainer_class = recursive_find_python_class(
        join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
        trainer_name,
        "nnunetv2.training.nnUNetTrainer",
    )
    if trainer_class is None:
        raise RuntimeError(f"Unable to locate trainer class {trainer_name}")
    model = trainer_class.build_network_architecture(
        configuration_manager.network_arch_class_name,
        configuration_manager.network_arch_init_kwargs,
        configuration_manager.network_arch_init_kwargs_req_import,
        num_input_channels,
        plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
        enable_deep_supervision=False,  # export only the full-resolution output
    )
    model.load_state_dict(checkpoint["network_weights"])
    model.eval()

    # dummy input of shape (batch, channels, *patch_size): patch_size has 2 entries
    # for 2d configurations and 3 for 3d ones, and the channel count must match the network
    dummy = torch.randn(batch_size, num_input_channels, *configuration_manager.patch_size)

    # convert to onnx (tracing only, so no gradients are needed)
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy,
            onnx_model_path,
            verbose=False,
            export_params=True,
            opset_version=12,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["output"],
        )
```
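Note that the exported graph only contains the network itself: the dummy input fixes its spatial size, and nnU-Net's preprocessing, sliding-window inference and postprocessing are not part of the ONNX file. For images larger than one tile you have to tile them yourself. Below is a minimal sketch (not nnU-Net's code) of evenly spaced, half-overlapping tile positions, similar in spirit to what nnU-Net's predictor does with its default step size of 0.5; `sliding_window_starts` and `step_fraction` are hypothetical names, and nnU-Net additionally weights overlapping predictions with a Gaussian, which this sketch does not do.

```python
import math
from typing import List, Sequence


def sliding_window_starts(image_size: Sequence[int], tile_size: Sequence[int],
                          step_fraction: float = 0.5) -> List[List[int]]:
    """Return, per spatial axis, the start index of every tile (hypothetical helper).

    Consecutive tiles are at most step_fraction * tile apart, the first tile
    starts at 0 and the last one ends exactly at the image border.
    """
    if len(image_size) != len(tile_size):
        raise ValueError("image_size and tile_size must have the same length")
    if not 0 < step_fraction <= 1:
        raise ValueError("step_fraction must be in (0, 1]")
    starts = []
    for img, tile in zip(image_size, tile_size):
        if img < tile:
            raise ValueError("image must be at least as large as the tile (pad it first)")
        # smallest number of tiles whose spacing does not exceed step_fraction * tile
        n = math.ceil((img - tile) / (tile * step_fraction)) + 1
        if n == 1:
            starts.append([0])
            continue
        # spread the tiles evenly between 0 and img - tile
        step = (img - tile) / (n - 1)
        starts.append([round(step * i) for i in range(n)])
    return starts
```

For example, `sliding_window_starts((300, 128), (128, 128))` gives `[[0, 57, 115, 172], [0]]`: four overlapping tiles along the first axis and a single tile along the second, each of which can be fed to the ONNX model.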

Hope it helps
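The `output` of the exported model holds raw logits with one channel per segmentation head (`num_segmentation_heads` in the code above). Assuming the model was trained on plain labels (nnU-Net's region-based training uses a sigmoid and per-region thresholds instead), the label map is the argmax over the channel axis, and since softmax does not change the ordering, the softmax can be skipped. A minimal pure-Python sketch, with `logits_to_labels` a hypothetical helper working on a flattened `(C, voxels)` output:

```python
from typing import List, Sequence


def logits_to_labels(logits: Sequence[Sequence[float]]) -> List[int]:
    """Argmax over the class axis (hypothetical helper).

    logits[c][i] is the score of class c at voxel i. Returns the winning
    class index per voxel; argmax of logits equals argmax of their softmax.
    """
    if not logits:
        raise ValueError("need at least one class channel")
    num_voxels = len(logits[0])
    if any(len(channel) != num_voxels for channel in logits):
        raise ValueError("all class channels must have the same length")
    labels = []
    for i in range(num_voxels):
        # max() returns the first maximal class, so ties go to the lowest index
        best = max(range(len(logits)), key=lambda c: logits[c][i])
        labels.append(best)
    return labels
```

With NumPy and a `(batch, C, ...)` output array the same thing is simply `output.argmax(axis=1)`.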