nod-ai / SHARK-TestSuite

Temporary home of a test suite we are evaluating
Apache License 2.0

Fix input generation for BERT NLP models by overriding `token_type_ids` to stay within the range [0, 2), plus some refactoring #357

Closed: renxida closed this 1 month ago

renxida commented 1 month ago

Note: this causes each NLP model to create an additional ONNX inference session. A larger-scale refactoring may be required to remove it.

Testing: I ran this locally with a BERT model, and the test got past input generation and as far as a numeric mismatch.

Full list of changes:

  1. Refactor input generation for ONNX models:

    • Split `get_node_shape_from_dim_param_dict()` out of `generate_input_from_node()` for better modularity.
    • Improve docstrings so they describe each function more clearly.
  2. Enhance type annotations:

    • Add return type hints to `construct_inputs()` and `get_sample_inputs_for_onnx_model()`.
    • Use the `TestTensors` type consistently for input and output tensors.
  3. Fix NLP (BERT)-specific input generation:

    • Override `construct_inputs()` in NLP models so that `token_type_ids` are generated only in the range [0, 2).
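For illustration, here is a minimal, dependency-free sketch of the two ideas in the list above: resolving symbolic `dim_param` dimensions to concrete sizes, and overriding the value range for `token_type_ids`. It is not the test suite's actual code. `resolve_shape`, `random_int_tensor`, `construct_int_inputs`, `RANGE_OVERRIDES`, and the default integer range [0, 100) are all hypothetical. The graph inputs are passed as plain `(name, shape)` tuples instead of being read from an ONNX model, so the sketch runs without `onnx` or `onnxruntime`. The [0, 2) bound reflects that standard BERT checkpoints use `type_vocab_size = 2`, so any other token type index falls outside the token-type embedding table.

```python
import random
from typing import Dict, List, Sequence, Tuple, Union

# A dimension is either a concrete size or a symbolic dim_param name.
Dim = Union[int, str]

# Hypothetical default range for generic integer inputs such as input_ids.
DEFAULT_INT_RANGE: Tuple[int, int] = (0, 100)

# Per-input overrides. Standard BERT has type_vocab_size = 2, so
# token_type_ids must be 0 or 1, i.e. drawn from [0, 2).
RANGE_OVERRIDES: Dict[str, Tuple[int, int]] = {"token_type_ids": (0, 2)}


def resolve_shape(shape: Sequence[Dim], dim_params: Dict[str, int]) -> List[int]:
    """Replace symbolic dims (e.g. "batch_size") with concrete sizes."""
    resolved = []
    for dim in shape:
        if isinstance(dim, int):
            resolved.append(dim)
        elif dim in dim_params:
            resolved.append(dim_params[dim])
        else:
            raise KeyError(f"no value given for dim_param {dim!r}")
    return resolved


def random_int_tensor(shape: List[int], low: int, high: int, rng: random.Random):
    """Build nested lists of ints drawn uniformly from [low, high)."""
    if not shape:
        return rng.randrange(low, high)
    return [random_int_tensor(shape[1:], low, high, rng) for _ in range(shape[0])]


def construct_int_inputs(
    inputs: Sequence[Tuple[str, Sequence[Dim]]],
    dim_params: Dict[str, int],
    seed: int = 0,
) -> Dict[str, object]:
    """Generate one random integer tensor per graph input, honoring overrides."""
    rng = random.Random(seed)
    tensors = {}
    for name, shape in inputs:
        low, high = RANGE_OVERRIDES.get(name, DEFAULT_INT_RANGE)
        tensors[name] = random_int_tensor(resolve_shape(shape, dim_params), low, high, rng)
    return tensors


# Usage: two symbolic dims resolved to batch_size=2, sequence_length=8.
graph_inputs = [
    ("input_ids", ["batch_size", "sequence_length"]),
    ("token_type_ids", ["batch_size", "sequence_length"]),
]
inputs = construct_int_inputs(graph_inputs, {"batch_size": 2, "sequence_length": 8})
```

Keeping the range overrides in a lookup table means a model-specific subclass only has to change the table rather than reimplement input generation. The PR instead overrides `construct_inputs()` in the NLP models, which is where the extra inference session mentioned above comes from.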