Just skimming this, it seems like the issue could be variable-length sequences. You may want to either tell ONNX that the length dimension is variable, or just pad everything to the max length of 1024. Another option: did you consider using TorchScript? https://pytorch.org/docs/stable/jit.html
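For reference, tracing would look roughly like this. This is a sketch, assuming `model` and `batch_tokens` are set up as in the ESM README; `strict=False` is needed because ESM's forward returns a dict, which the tracer only records in non-strict mode:

```python
import torch

# Trace the model with a representative input; the traced graph can then
# be run without the original Python model code.
with torch.no_grad():
    traced = torch.jit.trace(model, batch_tokens, strict=False)
traced.save("esm1b_traced.pt")

# Reload and run the traced module on new inputs of the same shape.
loaded = torch.jit.load("esm1b_traced.pt")
with torch.no_grad():
    out = loaded(batch_tokens)
```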
Thank you for the reply!
TorchScript is definitely worth considering, so I'll give that a try. Ideally, I'd also get ONNX to work so that we can compare the performance improvements of the two approaches.
I'll try using the `dynamic_axes` argument of `torch.onnx.export()` (though I'm honestly not sure what the syntax for that argument should look like) and share my updates here.
I added `dynamic_axes`, `input_names`, and `output_names` to the export, and now the forward pass works with varied-length sequences and different batch sizes:
```python
with torch.no_grad():
    torch.onnx.export(
        model,
        batch_tokens,
        args.converted_model_path,
        use_external_data_format=True,  # the 650M-parameter model exceeds the 2 GB protobuf limit
        opset_version=11,
        do_constant_folding=True,
        input_names=["inputs"],
        output_names=["outputs"],
        # mark axis 0 (batch) and axis 1 (sequence length) as dynamic
        dynamic_axes={"inputs": [0, 1]},
    )
```
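As an aside, `dynamic_axes` also accepts a dict form that names each dynamic axis, which makes the exported graph a bit more self-documenting; as far as I can tell the two forms are equivalent here:

```python
# Same effect as dynamic_axes={"inputs": [0, 1]}, but with named axes;
# "outputs" could be marked dynamic the same way if needed.
dynamic_axes = {"inputs": {0: "batch_size", 1: "sequence_length"}}
```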
The only issue that remains for me is passing the argument `repr_layers=[33]` in order to extract embeddings. It's not exactly related to the original issue in this post, so after some more trials, I'll try to get help from the PyTorch community and share the solution in a comment here.
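One workaround that may help, since `torch.onnx.export` traces the `forward` call with positional tensor inputs only: bake `repr_layers` into a small wrapper module and export that instead. This is an untested sketch; the wrapper class is hypothetical, though `repr_layers` and the `"representations"` output key follow the ESM API:

```python
import torch

class ESMEmbeddingWrapper(torch.nn.Module):
    """Hypothetical wrapper fixing repr_layers=[33] at export time."""

    def __init__(self, model, repr_layer=33):
        super().__init__()
        self.model = model
        self.repr_layer = repr_layer

    def forward(self, tokens):
        out = self.model(tokens, repr_layers=[self.repr_layer])
        # Return a plain tensor so the ONNX graph has one named output
        # instead of a Python dict.
        return out["representations"][self.repr_layer]

# Export the wrapper with the same torch.onnx.export arguments as above.
wrapped = ESMEmbeddingWrapper(model).eval()
```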
Summary:
I'm looking to use ONNX and ONNXRuntime to speed up our workloads that use ESM-1b.
I converted the serialized model `esm1b_t33_650M_UR50S.pt` to a `.onnx` graph using `torch.onnx`, then explicitly applied extended optimizations (including conversion to `float16`) using `onnxruntime_tools`. When trying to run an ONNXRuntime inference session with the following inputs:
I get:
ONNX Graph Conversion and ONNXRuntime Inference Code:
Before the ONNXRuntime inference code, I'd like to share my PyTorch-to-ONNX conversion process, just in case the issue resides there:
This is the source code for `convert_onnx_esm.py`:

Now for the inference code which produced the exception:
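A minimal sketch of the kind of ONNXRuntime inference call involved, assuming the `inputs`/`outputs` names from the export above; the provider choice and `batch_tokens` variable are assumptions, not the exact original script:

```python
import onnxruntime as ort

# Load the optimized graph; CUDAExecutionProvider assumes a GPU build of
# onnxruntime, with CPU as the fallback.
session = ort.InferenceSession(
    args.converted_model_path,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

# Feed the tokenized batch under the input name used at export time.
outputs = session.run(["outputs"], {"inputs": batch_tokens.cpu().numpy()})
print(outputs[0].shape)
```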
Environment:
CUDA 11.2, cuDNN 8.1.1.33, and Python 3.8.5 with the following packages:
Note: if we can find a solution to this issue, would it be a good idea for me to clean up the model conversion and inference example and contribute them to the repo?