ajithvcoder closed this issue 1 year ago.
Only the shape of the final output is broken, because of the Reshape just before the final output. Only the axis of the last Concat needs to be adjusted in the JSON.
pip install onnx2tf==1.16.5
onnx2tf \
-i d3net_dnn_double_44.onnx \
-kat input \
-cotof
A Pad with negative padding values is now converted to StridedSlice, since the TensorFlow runtime does not support padding with negative numbers.

@PINTO0309 Thanks a lot for supporting this. Could you kindly provide the parameter JSON needed to match the output? The converted model gives (1, 1, 2049, 44, 8), and even with a permute I can only get (1, 1, 8, 2049, 44). I don't know how to bring it to (1, 4, 2, 2049, 44); kindly advise me.
However, I am able to match the ONNX output by post-processing the TFLite output in Python:
import time

import numpy as np
import onnxruntime as ort
import tensorflow as tf

# Same all-ones input for both runtimes
input_np = np.ones((1, 1, 2, 2049, 44), dtype=np.float32)

# Reference output from ONNX Runtime
ort_session = ort.InferenceSession('./d3net_dnn_double_44.onnx', providers=["CPUExecutionProvider"])
outputs_onnx = ort_session.run(None, {'input': input_np})
print(outputs_onnx[0].shape)

# Load the converted TFLite model
interpreter = tf.lite.Interpreter(model_path="./saved_model/d3net_dnn_double_44_float32.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run inference (set_tensor expects a NumPy array)
interpreter.set_tensor(input_details[0]['index'], input_np)
start = time.time()
interpreter.invoke()
print("Done! {:.3f} s".format(time.time() - start))
output_data = interpreter.get_tensor(output_details[0]['index'])

# Post-process: (1, 1, 2049, 44, 8) -> (1, 1, 8, 2049, 44) -> (1, 8, 2049, 44) -> (1, 4, 2, 2049, 44)
output_data = np.transpose(output_data, (0, 1, 4, 2, 3))
output_data = np.squeeze(output_data, axis=1)
output_data = output_data.reshape(1, 4, 2, 2049, 44)
print(output_data.shape)

# Compare against the ONNX Runtime result
np.testing.assert_allclose(outputs_onnx[0], output_data, rtol=1e-03, atol=1e-01)
print("All match")
The modifiable parameters of each operation are inputs, outputs, attributes, and constants. In this case, we only need to correct the axis (attributes) of the Concat. If you also want to change the order of the output dimensions, simply apply a transpose to the output.
https://github.com/PINTO0309/onnx2tf#parameter-replacement
replace_d3net.json
{
"format_version": 1,
"operations": [
{
"op_name": "Concat_16132",
"param_target": "attributes",
"param_name": "axis",
"values": 1
},
{
"op_name": "Concat_16132",
"param_target": "outputs",
"param_name": "output",
"post_process_transpose_perm": [0,1,4,2,3]
}
]
}
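If you prefer to generate the replacement file from a script rather than by hand, the following is a minimal sketch using only the standard-library json module. It writes exactly the two entries shown above; the op name Concat_16132 is specific to this model and would differ for other ONNX files.

```python
import json

# The two parameter-replacement entries from this thread:
# 1) change the axis attribute of the last Concat to 1
# 2) transpose that Concat's output with perm [0, 1, 4, 2, 3]
replacement = {
    "format_version": 1,
    "operations": [
        {
            "op_name": "Concat_16132",
            "param_target": "attributes",
            "param_name": "axis",
            "values": 1,
        },
        {
            "op_name": "Concat_16132",
            "param_target": "outputs",
            "param_name": "output",
            "post_process_transpose_perm": [0, 1, 4, 2, 3],
        },
    ],
}

# Write the file that is passed to onnx2tf via -prf
with open("replace_d3net.json", "w") as f:
    json.dump(replacement, f, indent=2)
```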
onnx2tf \
-i d3net_dnn_double_44.onnx \
-kat input \
-cotof \
-prf replace_d3net.json
Okay, got it. Earlier it was concatenating on axis 4, so it was giving (1, 1, 2049, 44, 8). Now, after concatenating on axis 1, it is (1, 4, 2049, 44, 2), and after the transpose it becomes (1, 4, 2, 2049, 44). Thanks again for helping.
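The shape bookkeeping can be checked with a small NumPy sketch. It assumes, as the shapes reported above imply, that the last Concat joins four tensors of shape (1, 1, 2049, 44, 2); the random data are hypothetical stand-ins for the real activations. It also shows why the earlier Python post-processing (transpose, squeeze, reshape) matched ONNX: both routes produce the same tensor.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-ins for the four inputs of the last Concat,
# each assumed to be (1, 1, 2049, 44, 2) in the converted (channels-last) model.
parts = [rng.random((1, 1, 2049, 44, 2), dtype=np.float32) for _ in range(4)]

# Before the fix: concatenation on axis 4
bad = np.concatenate(parts, axis=4)
print(bad.shape)  # (1, 1, 2049, 44, 8)

# Python post-processing from the thread: transpose, squeeze, reshape
via_python = np.transpose(bad, (0, 1, 4, 2, 3))  # (1, 1, 8, 2049, 44)
via_python = np.squeeze(via_python, axis=1)       # (1, 8, 2049, 44)
via_python = via_python.reshape(1, 4, 2, 2049, 44)

# After the JSON fix: concatenation on axis 1 ...
fixed = np.concatenate(parts, axis=1)
print(fixed.shape)  # (1, 4, 2049, 44, 2)

# ... followed by post_process_transpose_perm [0, 1, 4, 2, 3]
out = np.transpose(fixed, (0, 1, 4, 2, 3))
print(out.shape)  # (1, 4, 2, 2049, 44)

# Both routes yield identical tensors
print(np.array_equal(out, via_python))  # True
```

Fixing the axis in the graph avoids the extra squeeze and reshape at inference time, which is why the JSON route is preferable to post-processing in Python.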
I am slowly improving this tool and will gradually eliminate the need for JSON for patterns that can be identified automatically.
This particular pattern looks like one that could be auto-corrected.
Thanks, but it's already a great tool; you are helping a lot of people.
Issue Type: Others
OS: Linux
onnx2tf version number: 1.16.2
onnx version number: 1.14.0
onnxruntime version number: 1.15.1
onnxsim (onnx_simplifier) version number: 0.4.33
tensorflow version number: 2.13.0
Download URL for ONNX: https://drive.google.com/file/d/18rrEyoVGxM2mQjGLWHDoLCOrhdxx0sFf/view?usp=sharing
Parameter Replacement JSON:
Description:
onnx2tf -i d3net_dnn_double_44.onnx
Paddings must be non-negative for '{{node tf.pad/Pad}} = Pad[T=DT_FLOAT, Tpaddings=DT_INT32](Placeholder, tf.pad/Pad/paddings)' with input shapes: [1,176,12,2], [4,2] and with computed input tensors: input[1] = <[0 0][0 0][0 -1][0 0]>.
Call arguments received by layer "tf.pad" (type TFOpLambda): • tensor=tf.Tensor(shape=(1, 176, 12, 2), dtype=float32) • paddings=tf.Tensor(shape=(4, 2), dtype=int32) • mode='constant' • constant_values=array(0., dtype=float32) • name='Pad_1913'
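The error above comes from paddings <[0 0][0 0][0 -1][0 0]> on a [1, 176, 12, 2] tensor: a negative pad amount in ONNX means "remove elements", which TensorFlow's Pad rejects. The following NumPy sketch illustrates the idea behind replacing such a Pad with a slice; it is a hypothetical helper for illustration, not onnx2tf's actual implementation (which would use TensorFlow ops such as tf.strided_slice). Negative amounts are applied as a crop, and the remaining non-negative amounts as ordinary constant padding.

```python
import numpy as np


def pad_with_negative(x, pads, constant_value=0.0):
    """Emulate a constant Pad whose (before, after) amounts may be negative.

    pads: one (before, after) pair per axis. Negative amounts crop that
    many elements from the corresponding end (a strided slice); non-negative
    amounts pad with constant_value.
    """
    # Step 1: crop every negative amount with a plain slice
    slices = []
    for before, after in pads:
        start = -before if before < 0 else 0
        stop = after if after < 0 else None
        slices.append(slice(start, stop))
    cropped = x[tuple(slices)]

    # Step 2: pad only the non-negative amounts
    positive = [(max(before, 0), max(after, 0)) for before, after in pads]
    return np.pad(cropped, positive, mode="constant", constant_values=constant_value)


# The case from the error message: drop the last element of axis 2
x = np.zeros((1, 176, 12, 2), dtype=np.float32)
y = pad_with_negative(x, [(0, 0), (0, 0), (0, -1), (0, 0)])
print(y.shape)  # (1, 176, 11, 2)
```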