Closed: ingura closed this issue 1 year ago.
I cannot access your Google Drive files. Please grant browse permission.
Sorry about that. Try now
I have confirmed that I can convert if I use this JSON. However, I found some bugs in the Concat
parameter replacement process, so wait until the next release.
https://github.com/PINTO0309/onnx2tf/blob/main/json_samples/replace_rtmdet_tiny.json
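Note that the use of multi-class NMS is not recommended. Converted model: rtmDet-tiny-res640-fp32_float32.tflite.zip (https://github.com/PINTO0309/onnx2tf/files/10846216/rtmDet-tiny-res640-fp32_float32.tflite.zip)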
That's great! Is there a parameter I can specify while using onnx2tf in order to skip NMS? The conversion including the NMS worked fine on yolov7.
Is there a parameter I can specify while using onnx2tf in order to skip NMS?
No. If you do not want multiple NMSs to be generated, you need to include ReduceMax and TopK in the post-processing, on the PyTorch side, to limit the number of classes to one. The problem is not with the NMS itself: if the number of classes input to the NMS is greater than 1, a separate NMS is generated for each class. This is a TFLite specification.
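The ReduceMax + TopK idea described above can be sketched as follows. This is only an illustration in NumPy, not the actual post-processing of the RTMDet export: the function name `reduce_to_single_class`, the `[N, 4]` / `[N, C]` input layouts, and the parameter `k` are assumptions made for the example. On the PyTorch side the same steps would presumably be written with `torch.max` and `torch.topk` before the NMS node.

```python
import numpy as np

def reduce_to_single_class(boxes, scores, k=30):
    """Collapse per-class scores to one score per box, then keep the top k.

    boxes:  [N, 4] array of box coordinates (layout assumed for this sketch)
    scores: [N, C] array of per-class scores (layout assumed for this sketch)
    """
    best_scores = scores.max(axis=1)      # ReduceMax over the class axis
    class_ids = scores.argmax(axis=1)     # remember which class won
    k = min(k, best_scores.shape[0])      # never ask for more boxes than exist
    # TopK: indices of the k highest scores, highest first
    top_idx = np.argsort(-best_scores, kind="stable")[:k]
    return boxes[top_idx], best_scores[top_idx], class_ids[top_idx]

# Tiny example: 3 candidate boxes, 2 classes, keep the best 2 boxes.
boxes = np.arange(12, dtype=np.float32).reshape(3, 4)
scores = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.4]], dtype=np.float32)
top_boxes, top_scores, top_classes = reduce_to_single_class(boxes, scores, k=2)
```

After this reduction the NMS only sees a single score per box, so it is no longer split per class; the class index is carried alongside and looked up after the NMS.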
That is included in the model linked above
The way your post-processing is written is not optimized enough and is still redundant.
Good to know
I tried the converted model and the output tensor initialization is better now, but it is fixed at 5 while the .onnx model has its output fixed at top 30. The bounding box tensor has the shape [30, 5]: 30 is the number of boxes and 5 is the number of parameters for each box.
The sample JSON is a very rough JSON that was only checked to make sure that the conversion would succeed. I'm working on other things right now and don't have time to adjust the JSON, so please try the subsequent TopK and other axis replacements yourself.
Thanks, I'll check it out.
Thanks Pinto, now it works on the Java side; however, there are still some bugs hiding in the code:
import cv2
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
usingRrmTiny = True
tfLiteTestDir = "your path to the model folder"
dataDir = "your path to the test data"
dataPath = dataDir + "1.jpg"
# Load the TFLite model (tensors are allocated further below).
if usingRrmTiny:
    interpreter = tf.lite.Interpreter(model_path=tfLiteTestDir + "Models\\pinto2-rtmDet-tiny-res640-fp32_float32.tflite")
#Name of the classes according to class indices.
names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush']
def scaleAndFill(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    # Scale ratio (new / old)
    scale = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        scale = min(scale, 1.0)
    # Compute padding
    new_unpad = int(round(shape[1] * scale)), int(round(shape[0] * scale))
    fillW, fillH = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # padding
    if auto:  # minimum rectangle
        fillW, fillH = np.mod(fillW, stride), np.mod(fillH, stride)  # padding
    fillW /= 2  # divide padding into 2 sides
    fillH /= 2
    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(fillH - 0.1)), int(round(fillH + 0.1))
    left, right = int(round(fillW - 0.1)), int(round(fillW + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # fill border
    return im, scale, (fillW, fillH)
#Load and preprocess the image.
img = cv2.imread(dataPath)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
image = img.copy()
if usingRrmTiny:
    image, scale, fill = scaleAndFill(image, (640, 640), auto=False)
image = np.expand_dims(image, 0)
image = np.ascontiguousarray(image)
im = image.astype(np.float32)
im /= 255
#Allocate tensors.
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], im)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
outputBoxScoreData = interpreter.get_tensor(output_details[0]['index'])
outputClassData = interpreter.get_tensor(output_details[1]['index'])
## Visualize results
#Creating random colors for bounding box visualization.
colors = {name:[random.randint(0, 255) for _ in range(3)] for i,name in enumerate(names)}
image = img.copy()
i = 0
if usingRrmTiny:
    for idx in range(outputBoxScoreData[0].shape[0]):
        name = names[int(outputClassData[0][idx])]
        score = round(float(outputBoxScoreData[0][idx][4]), 3)
        xLeft = outputBoxScoreData[0][idx][0]
        yTop = outputBoxScoreData[0][idx][1]
        xRight = outputBoxScoreData[0][idx][2]
        yBottom = outputBoxScoreData[0][idx][3]
        if score > 0.35:
            # xLeft and yTop are negated before the padding and scale are undone
            box = np.array([-xLeft, -yTop, xRight, yBottom])
            box -= np.array(fill*2)
            box /= scale
            box = box.round().astype(np.int32).tolist()
            color = colors[name]
            name += ' ' + str(score)
            cv2.rectangle(image, box[:2], box[2:], color, 1)
            cv2.putText(image, name, (box[0], box[1] - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.75, [225, 255, 255], thickness=2)
            i = i + 1
            print("{} confidence: {} , bbox:({},{})<>({},{}) , class :{}".format(i, score, round(float(xLeft)), round(float(yTop)), round(float(xRight)), round(float(yBottom)), name))
plt.imshow(image)
plt.title('TfLite Indications', fontweight ="bold")
plt.show()
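The coordinate mapping used in the visualization loop (subtract the padding, then divide by the scale) can be checked in isolation. The sketch below is a hypothetical helper, `undo_letterbox`, not part of the script above; it assumes the model returns corner coordinates (x1, y1, x2, y2) in the padded 640x640 input space, which may not hold for every converted model.

```python
import numpy as np

def undo_letterbox(box_xyxy, scale, pad_w, pad_h):
    """Map a box from the padded, resized input back to the original image."""
    box = np.asarray(box_xyxy, dtype=np.float64)
    box = box - np.array([pad_w, pad_h, pad_w, pad_h])  # remove the border offset
    return box / scale                                  # undo the resize

# Worked example: a 1280x720 (width x height) image letterboxed to 640x640.
scale = min(640 / 720, 640 / 1280)        # 0.5, limited by the width
pad_w = (640 - round(1280 * scale)) / 2   # 0.0, no horizontal padding
pad_h = (640 - round(720 * scale)) / 2    # 140.0 pixels above and below
restored = undo_letterbox([100, 200, 300, 400], scale, pad_w, pad_h)
print(restored.tolist())  # [200.0, 120.0, 600.0, 520.0]
```

Feeding a few hand-computed boxes through such a helper is one way to separate drawing errors in the script from differences in the converted model's output.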
onnx2tf \
-i rtmDet-tiny-res640-fp32.onnx \
-prf replace_rtmdet_tiny.json \
-onimc /Transpose_6_output_0 /Gather_6_output_0 /Transpose_7_output_0
The .onnx model performs fine but the converted model does not always replicate its performance. I'll take a look at the exposed debug information.
If I can rewrite its structure that would be amazing
Issue Type
Others
onnx2tf version number
1.5.40
onnx version number
1.12.0
tensorflow version number
2.10.1
Download URL for ONNX
https://drive.google.com/file/d/15UQMdyn-CCHzkXm75-WlptaJvqthDIU_/view?usp=sharing
Parameter Replacement JSON
Description
Hi there!
I have tried to convert the RTMDet ONNX model linked here: https://drive.google.com/file/d/15UQMdyn-CCHzkXm75-WlptaJvqthDIU_/view?usp=sharing. The model has a static output and includes NonMaxSuppression in its post-processing phase. I got the model from here: https://github.com/open-mmlab/mmdetection/tree/3.x/configs/rtmdet and it comes under the Apache 2.0 license: https://github.com/open-mmlab/mmdetection/blob/master/LICENSE
When using the default conversion command:
onnx2tf -i rtmDet-tiny-res640-fp32.onnx -o models-NHWc-final/
I run into the following error:
If I enable -onwdt, the .onnx model does convert to .tflite and the model works in Python, but it does not work with the TFLite Java API for mobile devices, so the resulting .tflite model cannot be deployed on a phone or tablet.
The problem is that, even though onnx2tf generates a static-output model (as it should), the TFLite interpreter reports the wrong output size after it loads the generated model. To replicate the issue you can use the following Python code:
The output tensor that the TFLite interpreter initializes has shape [1, 1, 5], as if it were a dynamic output, while after inference the output size becomes static at [30, 5], as it should since the input model is static. This is not a problem in Python, where the output tensor does not need to be initialized; however, it becomes a problem when I try to deploy the model on a phone using the Java or C++ TFLite API, because the API requires a correctly sized output tensor. When the interpreter loads the model, its output tensor is automatically initialized to the size reported in the converted .tflite model: initial_output_size = [1, 1, 5]. When the model is deployed I get a fatal exception, because the correct output size is [30, 5].
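One way to make the mismatch visible from Python is to compare the shape the interpreter reports for each output with the shape of the tensor actually returned after inference. The helper below is a minimal sketch; `compare_output_shapes` is a hypothetical name, and it only relies on `get_output_details()` and `get_tensor()` of `tf.lite.Interpreter`. It assumes `allocate_tensors()`, `set_tensor()` and `invoke()` have already been called.

```python
def compare_output_shapes(interpreter):
    """Return (name, declared_shape, runtime_shape) for every output tensor.

    declared_shape is what the interpreter reports in its output details;
    runtime_shape is the shape of the tensor returned after invoke().
    """
    rows = []
    for detail in interpreter.get_output_details():
        declared = tuple(int(d) for d in detail["shape"])
        runtime = tuple(interpreter.get_tensor(detail["index"]).shape)
        rows.append((detail["name"], declared, runtime))
    return rows

# Usage, after interpreter.invoke() in a script like the one in this thread:
# for name, declared, runtime in compare_output_shapes(interpreter):
#     print(name, "declared:", declared, "runtime:", runtime)
```

An output whose declared shape differs from its runtime shape is the kind of tensor that would cause trouble in an API that pre-allocates output buffers from the declared shape.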
I used the -onwdt option because the standard conversion would otherwise fail. However, I wonder if I used the wrong onnx2tf conversion options. Please advise.