Status: Open — ogencoglu opened this issue 3 months ago.
I have the same question. Could you give an example? Thank you very much.
Hi, conversion to ONNX can be done like this:
def convert_to_ONNX(
    model_dir: str,
    onnx_model_path: str,
    batch_size: int = 1,
    tile_size: int = 512,
    fold: int = 1,
    num_channels: int = 3,
):
    """Export the network of one trained nnU-Net v2 fold to an ONNX file.

    Reads ``dataset.json`` and ``plans.json`` from *model_dir*, restores the
    network from ``fold_<fold>/checkpoint_best.pth`` (loaded on CPU), and
    traces it with a random input of shape
    ``(batch_size, num_channels, tile_size, tile_size)``.

    Args:
        model_dir: Folder containing ``dataset.json``, ``plans.json`` and the
            ``fold_<n>`` subfolders.
        onnx_model_path: Destination path of the exported ONNX model.
        batch_size: Batch dimension of the dummy tracing input.
        tile_size: Height and width of the dummy tracing input.
        fold: Fold whose ``checkpoint_best.pth`` is exported; the string
            ``'all'`` is passed through unchanged, anything else goes
            through ``int()``.
        num_channels: Channel dimension of the dummy tracing input. It must
            equal the number of input channels the network is built with.

    Raises:
        ValueError: If *num_channels* differs from the network's input
            channel count.
        RuntimeError: If the trainer class named in the checkpoint cannot be
            located.
    """
    fold_key = int(fold) if fold != 'all' else fold
    # Build the checkpoint path from its parts exactly once; joining onto an
    # already complete path would duplicate the model_dir/fold prefix.
    checkpoint_path = join(model_dir, f'fold_{fold_key}', 'checkpoint_best.pth')

    dataset_json = load_json(join(model_dir, 'dataset.json'))
    plans_manager = PlansManager(load_json(join(model_dir, 'plans.json')))

    # NOTE(review): recent torch releases changed torch.load's weights_only
    # default; confirm this checkpoint (which also stores non-tensor entries
    # such as 'trainer_name' and 'init_args') still loads with your version.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    trainer_name = checkpoint['trainer_name']
    configuration_name = checkpoint['init_args']['configuration']
    configuration_manager = plans_manager.get_configuration(configuration_name)

    # Restore the network exactly as it was configured for training.
    num_input_channels = determine_num_input_channels(
        plans_manager, configuration_manager, dataset_json
    )
    if num_channels != num_input_channels:
        # Fail early with a clear message instead of an opaque shape error
        # raised from inside the ONNX tracer.
        raise ValueError(
            f"num_channels={num_channels} does not match the network's "
            f"{num_input_channels} input channel(s)"
        )

    trainer_class = recursive_find_python_class(
        join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
        trainer_name,
        'nnunetv2.training.nnUNetTrainer',
    )
    # Presumably None when no matching class is found — guard so the failure
    # is explicit rather than an AttributeError on the next line.
    if trainer_class is None:
        raise RuntimeError(f"Unable to locate trainer class {trainer_name!r}")

    model = trainer_class.build_network_architecture(
        configuration_manager.network_arch_class_name,
        configuration_manager.network_arch_init_kwargs,
        configuration_manager.network_arch_init_kwargs_req_import,
        num_input_channels,
        plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
        enable_deep_supervision=False,
    )
    model.load_state_dict(checkpoint['network_weights'])
    model.eval()

    # NOTE(review): this 4-D dummy only suits 2D configurations; a 3D network
    # would need a (batch, channels, depth, height, width) input — confirm.
    dummy = torch.randn(batch_size, num_channels, tile_size, tile_size)
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy,
            onnx_model_path,
            verbose=False,
            export_params=True,
            opset_version=12,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["output"],
        )
Hope it helps
Do you support ONNX export, and is it tested? Any examples would be appreciated.