Closed: renxida closed this 1 month ago
Note: this causes each NLP model to create an additional ONNX inference session. A larger-scale refactoring may be required to fix this.

Testing: I ran this locally with a BERT model and got as far as a numeric mismatch.

Full list of changes:

- Refactor input generation for ONNX models:
  - Split get_node_shape_from_dim_param_dict() out of generate_input_from_node() for better modularity.
- Enhance type annotations:
  - Annotate construct_inputs() and get_sample_inputs_for_onnx_model().
  - Use the TestTensors type consistently for input/output tensors.
- Fix NLP (BERT)-specific input generation:
  - Update construct_inputs() in NLP models to handle 'token_type_ids' correctly.
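
For reviewers, here is a rough, self-contained sketch of the idea behind the first and last items. It is not the code in this change. `InputSpec`, `resolve_shape()`, and `generate_input()` are hypothetical stand-ins for the real helpers, and the sketch uses plain Python lists instead of tensors. It also assumes that "correctly" for 'token_type_ids' means keeping the values inside the segment vocabulary, which is {0, 1} for a standard BERT model.

```python
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Union

# A dimension is either a fixed size or a symbolic dim_param name such as "batch".
Dim = Union[int, str]


@dataclass
class InputSpec:
    """Hypothetical stand-in for an ONNX graph input: a name plus a possibly symbolic shape."""
    name: str
    dims: Sequence[Dim]


def resolve_shape(spec: InputSpec, dim_params: Dict[str, int]) -> List[int]:
    """Replace symbolic dims with concrete sizes; fail loudly on unknown names."""
    shape: List[int] = []
    for d in spec.dims:
        if isinstance(d, int):
            shape.append(d)
        elif d in dim_params:
            shape.append(dim_params[d])
        else:
            raise ValueError(f"no value for dim_param {d!r} on input {spec.name!r}")
    return shape


def generate_input(
    spec: InputSpec,
    dim_params: Dict[str, int],
    vocab_size: int = 100,  # assumed upper bound for ordinary id inputs
    rng: Optional[random.Random] = None,
):
    """Build a nested list of integer ids with the resolved shape.

    Assumption: 'token_type_ids' holds segment ids, so it is limited to {0, 1}.
    Every other input draws from range(vocab_size).
    """
    rng = rng or random.Random(0)
    shape = resolve_shape(spec, dim_params)
    high = 2 if spec.name == "token_type_ids" else vocab_size

    def fill(dims: List[int]):
        if not dims:
            return rng.randrange(high)
        return [fill(dims[1:]) for _ in range(dims[0])]

    return fill(shape)
```

Keeping shape resolution in its own function means it can be checked without generating any data, and the name-based special case stays in one place in the generator.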