snakers4 / silero-models

Silero Models: pre-trained speech-to-text, text-to-speech and text-enhancement models made embarrassingly simple

ONNX version of the te_model #216

Closed: pafullb closed this issue 1 year ago

pafullb commented 1 year ago

Trying to get an ONNX version of the te_model

I am trying to convert the 'te_model' (https://models.silero.ai/te_models/v2_4lang_q.pt) to ONNX.

I could not find a published ONNX version of this model.
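One way to see what a torch.package archive actually contains (for example, whether it holds anything besides the pickled model) is `PackageImporter.file_structure()`. The sketch below is self-contained: it builds a throwaway in-memory package whose pickle is placeholder data (the `{"examples": [...]}` dict is dummy content, not the real Silero model) and then inspects it the same way one would inspect `v2_4lang_q.pt`.

```python
import io

from torch import package

# Build a throwaway package in memory. The pickled dict is placeholder data only.
buf = io.BytesIO()
with package.PackageExporter(buf) as exporter:
    exporter.save_pickle("te_model", "model", {"examples": ["hello world"]})

# Re-open it and list the files stored in the archive.
buf.seek(0)
imp = package.PackageImporter(buf)
print(imp.file_structure())

# load_pickle(package, resource) is the same call used on the real archive.
obj = imp.load_pickle("te_model", "model")
print(obj)
```

On the real file, `package.PackageImporter("v2_4lang_q.pt").file_structure()` should print the archive's file tree; whether that tree includes any ONNX artifact is not something this sketch can confirm.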

This is what I tried for the conversion:

import torch
from torch import package
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Load the pickled model object from the torch.package archive
imp = package.PackageImporter("v2_4lang_q.pt")
model = imp.load_pickle("te_model", "model")

# The underlying torch.nn.Module wrapped by the loaded object
torch_model = model.model
print(torch_model)

# Tokenize one of the bundled examples to get dummy inputs (input_ids, attention_mask)
dummy_model_input = tokenizer(model.examples[0], return_tensors="pt")

input_names = ["input"]
output_names = ["output"]

torch.onnx.export(
    torch_model,                              # model being run
    args=tuple(dummy_model_input.values()),   # model inputs as a tuple of tensors
    f="v2_4lang_q.onnx",                      # where to save the model (file path or file-like object)
    export_params=True,                       # store the trained parameter weights inside the model file
    opset_version=12,                         # the ONNX opset version to export to
    do_constant_folding=True,                 # execute constant folding for optimization
    input_names=input_names,                  # the model's input names
    output_names=output_names,                # the model's output names
    dynamic_axes={"input": {0: "batch_size"},     # variable-length axes
                  "output": {0: "batch_size"}},
)
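A side note on the export call: `tuple(dummy_model_input.values())` passes two tensors for a DistilBERT tokenizer (`input_ids` and `attention_mask`), while `input_names` lists only one name, and `dynamic_axes` keys must match entries in `input_names`/`output_names`. The pure-Python sketch below builds matching names and axes from the tokenizer keys. The helper `onnx_export_names` is hypothetical, and whether the Silero `model.model` accepts these two tensors at all is an assumption not confirmed by the issue.

```python
def onnx_export_names(input_keys, output_names=("output",)):
    """Hypothetical helper: build input_names, output_names and dynamic_axes
    for torch.onnx.export from tokenizer output keys.

    Every input gets a dynamic batch axis (0) and sequence axis (1); every
    output gets only a dynamic batch axis (an assumption about output shape).
    """
    input_names = list(input_keys)
    dynamic_axes = {name: {0: "batch_size", 1: "seq_len"} for name in input_names}
    for name in output_names:
        dynamic_axes[name] = {0: "batch_size"}
    return input_names, list(output_names), dynamic_axes


# Keys a DistilBERT tokenizer returns; passing tensors in this explicit order
# keeps args aligned with input_names.
keys = ["input_ids", "attention_mask"]
input_names, output_names, dynamic_axes = onnx_export_names(keys)
print(input_names)
print(dynamic_axes)

# Intended use (not executed here):
# torch.onnx.export(torch_model,
#                   args=tuple(dummy_model_input[k] for k in keys),
#                   f="v2_4lang_q.onnx", opset_version=12,
#                   input_names=input_names, output_names=output_names,
#                   dynamic_axes=dynamic_axes)
```

Building `args` from an explicit key list rather than `.values()` makes the positional pairing between tensors and names visible instead of relying on dict order.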

But I get this error:

RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: torch._C.ScriptObject
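One plausible cause, which the issue itself does not confirm: the `_q` suffix suggests a quantized checkpoint, and PyTorch's quantized `Linear` layers keep their packed weights as `torch._C.ScriptObject` instances, the exact type named in the error. The sketch below scans an eager `nn.Module` for attributes of that type, demonstrated on a tiny dynamically quantized model. `find_script_objects` is a hypothetical helper, and `torch._C` is a private namespace. If `model.model` turns out to be a TorchScript module, its attributes are not exposed through `vars()`, so an empty result there is inconclusive.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic


def find_script_objects(model: nn.Module) -> list[str]:
    """Hypothetical helper: return 'module.attr' paths of plain module
    attributes that hold torch._C.ScriptObject values (e.g. packed weights)."""
    hits = []
    for mod_name, mod in model.named_modules():
        for attr_name, value in vars(mod).items():
            if isinstance(value, torch._C.ScriptObject):
                hits.append(f"{mod_name}.{attr_name}" if mod_name else attr_name)
    return hits


# Tiny float model, then a dynamically quantized copy (inplace=False by default).
float_model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
quant_model = quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

print(find_script_objects(float_model))
print(find_script_objects(quant_model))
```

An empty list for the float model and one hit per quantized `Linear` layer would point toward the quantized weights, rather than the tokenizer inputs, as the source of the unsupported `ScriptObject`.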