opensearch-project / opensearch-py-ml

[ENHANCEMENT] use torch.onnx.export instead of convert_graph_to_onnx #347

Open HenryL27 opened 9 months ago

HenryL27 commented 9 months ago

Is your feature request related to a problem? This relates to compiling ONNX models for upload to OpenSearch.

Two problems with the status quo:

  1. transformers.convert_graph_to_onnx.convert is deprecated and slated for removal in the next major version of Hugging Face transformers.
  2. transformers.convert_graph_to_onnx.convert only exports the base model, so any head on top of the base model's output is left off. For embedding models, we got around this by implementing the pooling layer in ml-commons for ONNX models, but for other pretrained classification heads (e.g. cross-encoders) this is not possible.
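To make the second problem concrete, here is a minimal sketch of the kind of post-processing that has to live outside the ONNX graph when only the base model is exported. It assumes the pooling in question is masked mean pooling, a common choice for sentence-embedding models; the actual ml-commons implementation is not shown in this issue, and the function name `mean_pool` is hypothetical.

```python
from typing import List


def mean_pool(token_embeddings: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average the token vectors whose attention-mask entry is 1 (hypothetical helper)."""
    dim = len(token_embeddings[0])
    # Keep only real tokens; padding positions have mask 0.
    kept = [vec for vec, m in zip(token_embeddings, attention_mask) if m == 1]
    if not kept:
        return [0.0] * dim
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


# Three token vectors from a base model; the last one is padding.
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
pooled = mean_pool(tokens, mask)
print(pooled)  # [2.0, 3.0]
```

A pooling step like this is simple enough to reimplement by hand. A learned classification head carries trained weights, so it cannot be recreated outside the exported graph in the same way.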

What solution would you like? Use torch.onnx.export instead. An example that implements this for cross-encoders:

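import torch

# `model`, `features`, and `mname` are assumed to be defined earlier:
# the loaded cross-encoder, its tokenized example inputs, and the model name.
torch.onnx.export(
    model=model,
    # Example inputs used to trace the model, passed positionally to forward().
    args=(features['input_ids'], features['attention_mask'], features['token_type_ids']),
    f=f"/tmp/{mname}.onnx",
    input_names=['input_ids', 'attention_mask', 'token_type_ids'],
    output_names=['output'],
    # Mark the batch and sequence dimensions as variable in the exported graph.
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'token_type_ids': {0: 'batch_size', 1: 'sequence_length'},
        'output': {0: 'batch_size'},
    },
)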
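The dynamic_axes mapping above repeats the same entry for every input, and some tokenizers (DistilBERT, for example) do not produce token_type_ids. A small helper could build the mapping from whatever input names a model actually uses. This is only a sketch: `build_dynamic_axes` is a hypothetical name, not an existing function in opensearch-py-ml or PyTorch.

```python
from typing import Dict, List


def build_dynamic_axes(
    input_names: List[str], output_names: List[str]
) -> Dict[str, Dict[int, str]]:
    """Build a dynamic_axes dict in the shape torch.onnx.export expects (hypothetical helper)."""
    # Every input varies in both batch size and sequence length.
    axes = {name: {0: "batch_size", 1: "sequence_length"} for name in input_names}
    # Outputs here vary only in batch size, as in the cross-encoder example.
    axes.update({name: {0: "batch_size"} for name in output_names})
    return axes


dynamic_axes = build_dynamic_axes(
    ["input_ids", "attention_mask", "token_type_ids"], ["output"]
)
```

The result can be passed directly as the dynamic_axes argument, with the same name lists reused for input_names and output_names.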

Usage is similar to torch.jit.trace, which we use for TorchScript compilation.
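To illustrate the similarity, the sketch below traces a tiny stand-in model with torch.jit.trace and defines an ONNX export that takes the same model and the same tuple of example inputs. `TinyClassifier` and `export_to_onnx` are hypothetical names made up for this illustration; a real cross-encoder would come from a pretrained checkpoint. The ONNX export is wrapped in a function and not called here, since it writes a file.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a cross-encoder: embedding, masked mean, linear head."""

    def __init__(self, vocab_size: int = 100, hidden: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(self, input_ids, attention_mask):
        hidden = self.embed(input_ids)  # (batch, seq, hidden)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.head(pooled)  # (batch, 1); the head is part of the graph


model = TinyClassifier().eval()
example_inputs = (
    torch.randint(0, 100, (2, 5)),
    torch.ones(2, 5, dtype=torch.long),
)

# TorchScript: trace the model with example inputs.
traced = torch.jit.trace(model, example_inputs)


def export_to_onnx(path: str) -> None:
    """ONNX: same model, same example inputs, plus names and dynamic axes."""
    torch.onnx.export(
        model,
        example_inputs,
        path,
        input_names=["input_ids", "attention_mask"],
        output_names=["output"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
            "output": {0: "batch_size"},
        },
    )
```

In both cases the whole forward pass is captured, including the head, which is the property the issue is asking for.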

This will simplify the code in ml-commons that drives ONNX models, since the exported graph already includes the head.

What alternatives have you considered? There are probably other ways to export a complete model to ONNX (and if we want to support TensorFlow we might need to look at options for that), but this seems pretty clean.

Do you have any additional context? Original comment

We should probably invest in supporting all the new kinds of models that will be coming from [RFC] Support more local model types in opensearch-py-ml.

dblock commented 3 months ago

[Triage -- attendees 1, 2, 3, 4, 5, 6, 7]

Can this be closed with #1615?