onnx / keras-onnx

Convert tf.keras/Keras models to ONNX
Apache License 2.0

Keras2onnx support dynamic input? #684

Open Zjq9409 opened 3 years ago

Zjq9409 commented 3 years ago

    from transformers import BertTokenizer, TFBertForSequenceClassification
    import tensorflow as tf

    tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')
    model = TFBertForSequenceClassification.from_pretrained('./bert-base-chinese')

    inputs = tokenizer("rch能保证是不是回因为环境差异带来运行错误。", return_tensors="tf")
    print(inputs)

    logits = model(inputs)
    print(logits)

    import keras2onnx
    onnx_model = keras2onnx.convert_keras(model, model.name)
    output_model_path = "chinese_roberta_l-12_H-768.onnx"
    keras2onnx.save_model(onnx_model, output_model_path)

I find that keras2onnx does not have an interface for dynamic inputs (a dynamic batch size and a dynamic input length), whereas the PyTorch-to-ONNX exporter lets you name the input axes and mark them as variable-length:

    def export_onnx_model(args, model, tokenizer, onnx_model_path):
        with torch.no_grad():
            inputs = {'input_ids': torch.ones(1, 128, dtype=torch.int64),
                      'attention_mask': torch.ones(1, 128, dtype=torch.int64),
                      'token_type_ids': torch.ones(1, 128, dtype=torch.int64)}
            outputs = model(**inputs)

            symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
            torch.onnx.export(model,                              # model being run
                              (inputs['input_ids'],               # model input (a tuple for multiple inputs)
                               inputs['attention_mask'],
                               inputs['token_type_ids']),
                              onnx_model_path,                    # where to save the model (a file or file-like object)
                              opset_version=11,                   # the ONNX opset version to export to
                              do_constant_folding=True,           # whether to execute constant folding for optimization
                              input_names=['input_ids',           # the model's input names
                                           'input_mask',
                                           'segment_ids'],
                              output_names=['output'],            # the model's output names
                              dynamic_axes={'input_ids': symbolic_names,    # variable-length axes
                                            'input_mask': symbolic_names,
                                            'segment_ids': symbolic_names})
        logger.info("ONNX Model exported to {0}".format(onnx_model_path))
oborchers commented 3 years ago

I have the same issue with the version installed from PyPI.

oborchers commented 3 years ago

@jianqianzhou: If you have the model saved somewhere, you can pass the from_tf=True argument to the PyTorch model class and the transformers library will automatically convert the TF model to PyTorch.
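That route can be sketched as below: load the TF checkpoint into the PyTorch class with from_tf=True, then export with torch.onnx.export and dynamic_axes. This is only a sketch under my own assumptions — the function name, paths, and seq_len are placeholders, and imports are kept inside the function:

```python
def convert_via_pytorch(model_dir, onnx_path, seq_len=128):
    """Sketch: load TF weights into the PyTorch model class, then export
    to ONNX with dynamic batch/sequence axes."""
    import torch
    from transformers import BertForSequenceClassification

    # from_tf=True tells transformers to convert the TF checkpoint on load
    model = BertForSequenceClassification.from_pretrained(model_dir, from_tf=True)
    model.eval()

    dummy = torch.ones(1, seq_len, dtype=torch.int64)
    symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
    torch.onnx.export(model,
                      (dummy, dummy, dummy),  # input_ids, attention_mask, token_type_ids
                      onnx_path,
                      opset_version=11,
                      do_constant_folding=True,
                      input_names=['input_ids', 'attention_mask', 'token_type_ids'],
                      output_names=['output'],
                      dynamic_axes={'input_ids': symbolic_names,
                                    'attention_mask': symbolic_names,
                                    'token_type_ids': symbolic_names})
```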