Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Add checks for model spec and matching output values in `to_onnx()` method #7279

Open addisonklinke opened 3 years ago

addisonklinke commented 3 years ago

🚀 Feature

The official PyTorch tutorial on exporting to ONNX also checks the exported model against the ONNX spec and confirms that its output values match those of the original PyTorch model. Lightning's current to_onnx() implementation only calls torch.onnx.export and skips both checks. As a best practice, I think they should be included (if not by default, then behind an optional boolean flag that turns them on).

Motivation

  1. Follow the official tutorial as closely as possible
  2. Ensure the ONNX version of the model does not produce unexpected results before it's used in production

Pitch

Add something along the lines of the following code to the end of the current to_onnx() method:

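import numpy as np
import onnx
import onnxruntime
import torch

def to_numpy(tensor):
    # Detach from the autograd graph if needed, move to CPU, then convert to NumPy
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# Structural check: validate the exported graph against the ONNX spec
onnx_model = onnx.load(file_path)
onnx.checker.check_model(onnx_model)

# Numerical check: run input_sample through ONNX Runtime on CPU
ort_session = onnxruntime.InferenceSession(file_path, providers=["CPUExecutionProvider"])
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input_sample)}
ort_outs = ort_session.run(None, ort_inputs)

# Compare against the PyTorch reference outputs in kwargs["example_outputs"];
# wrap a single tensor so zip() pairs whole outputs instead of batch rows
torch_outs = kwargs["example_outputs"]
if isinstance(torch_outs, torch.Tensor):
    torch_outs = (torch_outs,)
for torch_out, ort_out in zip(torch_outs, ort_outs):
    np.testing.assert_allclose(to_numpy(torch_out), ort_out, rtol=1e-03, atol=1e-05)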

Alternatives

Users could override the method and implement the checks themselves. However, I think these checks are general enough to belong in the base method, which would limit boilerplate code.
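
To make the trade-off concrete, the override route, combined with the opt-in flag suggested above, could look roughly like the sketch below. StandInModule is a hypothetical stand-in for LightningModule that only mimics the to_onnx(file_path, input_sample=None, **kwargs) call shape, and the check_export flag and verify_onnx hook are hypothetical names, not existing Lightning API. In real code, verify_onnx would hold the onnx.checker / onnxruntime comparison from the Pitch.

```python
from typing import Any, List, Optional


class StandInModule:
    """Hypothetical stand-in for LightningModule; only mimics the to_onnx() call shape."""

    def to_onnx(self, file_path: str, input_sample: Optional[Any] = None, **kwargs: Any) -> None:
        # The real method calls torch.onnx.export here; this stand-in just records the path.
        self.exported_to = file_path


class CheckedExportModule(StandInModule):
    """Override pattern: export as usual, then optionally run the checks."""

    def verify_onnx(self, file_path: str, input_sample: Any) -> None:
        # Hypothetical hook for onnx.checker.check_model plus the
        # onnxruntime output comparison from the Pitch.
        raise NotImplementedError

    def to_onnx(self, file_path: str, input_sample: Optional[Any] = None, *,
                check_export: bool = False, **kwargs: Any) -> None:
        super().to_onnx(file_path, input_sample, **kwargs)
        if check_export:  # hypothetical opt-in flag; checks are off by default
            self.verify_onnx(file_path, input_sample)


class DemoModule(CheckedExportModule):
    def __init__(self) -> None:
        self.verified: List[str] = []

    def verify_onnx(self, file_path: str, input_sample: Any) -> None:
        # Record the call instead of loading a real ONNX file
        self.verified.append(file_path)


m = DemoModule()
m.to_onnx("model.onnx", input_sample=[1.0])                     # export only
m.to_onnx("model.onnx", input_sample=[1.0], check_export=True)  # export, then verify
```

Making check_export keyword-only keeps it from colliding with positional arguments, and every project that wants the checks would need to repeat this subclass, which is the boilerplate the request aims to remove.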

Additional context

edenlightning commented 3 years ago

Thanks for the issue, @addisonklinke! Want to try submitting a PR?

vballoli commented 3 years ago

I've added a draft PR trying to solve this issue. I've tackled it in a very inefficient way, but if this is the way to go, I'll be happy to clean it up and convert it into an actual PR.

addisonklinke commented 3 years ago

@vballoli Thanks for getting the ball rolling on this one. I haven't had much time to work on a PR, but I will review yours and leave comments there where applicable.

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

addisonklinke commented 3 years ago

PR #7458 (https://github.com/PyTorchLightning/pytorch-lightning/pull/7458) is still in progress.