facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins
MIT License

"invalid dimensions for input" When Running with ONNXRuntime #61

Closed: ashrafgt closed this issue 3 years ago

ashrafgt commented 3 years ago

Summary:

I'm looking to use ONNX and ONNXRuntime to speed up our ESM-1b workloads.

I converted the serialized model esm1b_t33_650M_UR50S.pt to a .onnx graph using torch.onnx, then explicitly applied extended optimizations (including conversion to float16) using onnxruntime_tools.

When trying to run an ONNXRuntime inference session with the following inputs:

data = [
    ("protein1", "VLAGG"),
    ("protein2", "KALTARQ"),
]

I get:

InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input: input.1 for the following indices
index: 1 Got: 8 Expected: 9

ONNX Graph Conversion and ONNXRuntime Inference Code:

Before the ONNXRuntime inference code, here is my PyTorch-to-ONNX conversion process, in case the issue lies there:

export MODEL_PATH=/mnt/models/esm1b/esm1b_t33_650M_UR50S.pt # model file downloaded locally
export CONVERTED_GRAPH_PATH=/tmp/models/onnx_esm/graph.onnx # intermediate storage for the graph and the external data binaries ("/tmp/models/onnx_esm" must be created beforehand)
export OPTIMIZED_GRAPH_PATH=/mnt/models/onnx_esm/graph.onnx # final form of the graph, encapsulated within a single 1.3G file ("/mnt/models/onnx_esm" must be created beforehand)

python convert_onnx_esm.py --model-path $MODEL_PATH --converted-model-path $CONVERTED_GRAPH_PATH
python -m onnxruntime_tools.optimizer_cli --float16 --opt_level 99 --use_gpu --model_type bert --hidden_size 1024 --num_heads 16 --input $CONVERTED_GRAPH_PATH --output $OPTIMIZED_GRAPH_PATH

This is the source code for convert_onnx_esm.py:

import os
import torch
import torch.onnx
import argparse
from esm.pretrained import load_model_and_alphabet_local

parser = argparse.ArgumentParser()

parser.add_argument("--model-path", type=str, required=True)
parser.add_argument("--converted-model-path", type=str, required=True)
args = parser.parse_args()

model, alphabet = load_model_and_alphabet_local(args.model_path)
batch_converter = alphabet.get_batch_converter()

data = [
    ("protein1", "VLAGG"),
    ("protein2", "KALTARQ"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

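with torch.no_grad():
    # torch.onnx.export traces the model on batch_tokens; without dynamic_axes,
    # the graph's input shape stays fixed to batch_tokens.shape (2 x 9 here)
    torch.onnx.export(model,
        batch_tokens,
        args.converted_model_path,
        use_external_data_format=True,  # ~2.6 GB of float32 weights exceed the 2 GB protobuf limit
        opset_version=12,
        do_constant_folding=True
    )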

Now for the inference code which produced the exception:

import os
import numpy as np
import argparse
from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
    get_device,
)
from esm.data import Alphabet

provider = "CUDAExecutionProvider" if get_device() == "GPU" else "CPUExecutionProvider"

parser = argparse.ArgumentParser()

parser.add_argument("--optimized-model-path", type=str, required=True)
args = parser.parse_args()

options = SessionOptions()
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
model = InferenceSession(args.optimized_model_path, options)
model.set_providers([provider])
alphabet = Alphabet.from_architecture("protein_bert_base")
batch_converter = alphabet.get_batch_converter()

data = [
    ("protein1", "VLAGG"),
    ("protein2", "KALTARQ"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)

output = model.run(None, {"input.1": batch_tokens.numpy()})  # "input.1" is the name torch.onnx.export() assigned, since no input_names were given
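The two numbers in the error line up with the two alphabets used in the scripts above. The export script takes its alphabet from `load_model_and_alphabet_local`, i.e. the ESM-1b alphabet, which (as far as I can tell from fair-esm's `esm/data.py`) prepends a BOS (`<cls>`) token and appends an EOS token. The inference script builds `Alphabet.from_architecture("protein_bert_base")`, the ESM-1 alphabet, which prepends BOS but does not append EOS. Since the batch converter pads every sequence to the longest one in the batch, the arithmetic can be sketched as follows (`padded_length` is a hypothetical helper, not part of fair-esm):

```python
def padded_length(seqs, prepend_bos, append_eos):
    """Hypothetical helper: width of the token tensor a batch converter
    produces when every sequence is padded to the longest one."""
    longest = max(len(s) for s in seqs)
    return longest + int(prepend_bos) + int(append_eos)


seqs = ["VLAGG", "KALTARQ"]

# ESM-1b alphabet (export script): BOS + EOS around each sequence
print(padded_length(seqs, prepend_bos=True, append_eos=True))   # 9
# ESM-1 / "protein_bert_base" alphabet (inference script): BOS only
print(padded_length(seqs, prepend_bos=True, append_eos=False))  # 8
```

So even once the length axis is dynamic, the inference script should use the same alphabet as the export (for example the one returned by `load_model_and_alphabet_local`); as far as I can tell, the `<cls>` token is also mapped to a different index in the two alphabets, so mixing them would feed the model a wrong BOS token.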

Environment:

CUDA 11.2, cuDNN 8.1.1.33, and Python 3.8.5, with packages:

fair-esm==0.3.0
onnx==1.8.1
onnxconverter-common==1.6.0
onnxruntime-gpu==1.7.0
onnxruntime-tools==1.6.0

Note: if we find a solution for this issue, would it be worthwhile for me to clean up the model conversion and inference example and contribute them to the repo?

tomsercu commented 3 years ago

Just skimming this, it seems like the issue could be variable-length sequences? You may want to either tell ONNX that the length dimension is variable, or just pad everything to the maximum length of 1024. Another option: have you considered using TorchScript? https://pytorch.org/docs/stable/jit.html
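The fixed-length route amounts to right-padding every token batch to the width the graph was exported with. A minimal NumPy sketch, assuming the graph was exported with a dummy input already padded to `max_len` columns and that `pad_idx` is the alphabet's padding index (`alphabet.padding_idx` in fair-esm); `pad_to_length` is a hypothetical helper:

```python
import numpy as np


def pad_to_length(tokens: np.ndarray, max_len: int, pad_idx: int) -> np.ndarray:
    """Hypothetical helper: right-pad a (batch, seq) token array to (batch, max_len).

    Raises if the batch is already wider than the fixed length baked into the graph.
    """
    seq_len = tokens.shape[1]
    if seq_len > max_len:
        raise ValueError(f"sequence length {seq_len} exceeds fixed length {max_len}")
    # Pad only the sequence axis (axis 1), on the right, with the pad token id.
    return np.pad(tokens, ((0, 0), (0, max_len - seq_len)), constant_values=pad_idx)


tokens = np.arange(9, dtype=np.int64).reshape(1, 9)  # toy ids, not real ESM tokens
padded = pad_to_length(tokens, max_len=12, pad_idx=1)
print(padded.shape)  # (1, 12)
```

The trade-off is that every call then pays for the full padded length, and self-attention cost grows quadratically with sequence length, so marking the length axis as dynamic is usually the cheaper option for short proteins.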

ashrafgt commented 3 years ago

Thank you for the reply!

TorchScript is definitely worth considering, so I'll give it a try. Ideally, I'd also get ONNX working so that we can compare the performance improvements of the two.

I'll try the dynamic_axes argument of torch.onnx.export() (though I'm honestly not sure what its syntax should be) and share my updates here.

ashrafgt commented 3 years ago

I added dynamic_axes, input_names, and output_names to the export, and now the forward pass works with variable-length sequences and different batch sizes:

with torch.no_grad():
    torch.onnx.export(model,
        batch_tokens,
        args.converted_model_path,
        use_external_data_format=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=["inputs"],
        output_names=["outputs"],
        dynamic_axes={"inputs": [0, 1]}  # axis 0 = batch size, axis 1 = sequence length
    )

The only issue that remains for me is passing repr_layers=[33] in order to extract embeddings. It's not directly related to the original issue in this post, so after some more trials, I'll try to get help from the PyTorch community and share the solution in a comment here.
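For the remaining `repr_layers` question, one common approach (not something confirmed in this thread) is to wrap the model in a small `nn.Module` whose `forward` takes only the tokens and fixes the keyword argument, so that tracing records just the selected representation. A minimal sketch: `ReprWrapper` is a hypothetical name, and `ToyESM` is a hypothetical stand-in that only mimics the dictionary an ESM model returns (`{"logits": ..., "representations": {layer: tensor}}`), so the example runs without downloading weights.

```python
import torch
import torch.nn as nn


class ReprWrapper(nn.Module):
    """Hypothetical wrapper: exposes one layer's representations as the only output."""

    def __init__(self, model: nn.Module, layer: int):
        super().__init__()
        self.model = model
        self.layer = layer

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        out = self.model(tokens, repr_layers=[self.layer])
        return out["representations"][self.layer]


class ToyESM(nn.Module):
    """Hypothetical stand-in that mimics the output dict of an ESM model."""

    def __init__(self, vocab: int = 33, dim: int = 8, num_layers: int = 2):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.lm_head = nn.Linear(dim, vocab)

    def forward(self, tokens, repr_layers=()):
        x = self.embed(tokens)
        reps = {0: x} if 0 in repr_layers else {}
        for i, layer in enumerate(self.layers, start=1):
            x = torch.relu(layer(x))
            if i in repr_layers:
                reps[i] = x
        return {"logits": self.lm_head(x), "representations": reps}


wrapped = ReprWrapper(ToyESM(), layer=2).eval()
tokens = torch.randint(0, 33, (2, 9))
with torch.no_grad():
    emb = wrapped(tokens)
print(emb.shape)  # torch.Size([2, 9, 8])
```

With the real model, the idea would be to export `ReprWrapper(model, 33)` in place of `model`, keeping the same `dynamic_axes` for the input and adding the batch and length axes of the output; for ESM-1b the output would then carry a third, 1280-wide embedding axis. I have not verified this export end to end.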