atmccarthy opened this issue 3 years ago (status: Open)
Thanks for making this request @atmccarthy. So actually, if you have this converted to run on the Myriad X, it is possible to run it in a non-YOLO-parsing way, such that the metadata results are sent to the host unchanged and can then be parsed there.
So if this is converted already, this technique can be used instead. I'll see if I can find a reference for doing this.
Found one: https://github.com/luxonis/depthai-experiments/tree/master/gen2-efficientDet
This example uses host-side decoding of the neural data, so DepthAI does not need to know anything about the model being run - it simply spits out the raw neural outputs to the host.
So this is a great way to start for getting YOLOX running if you'd like to get it going now, as then (1) we don't block you, and (2) we'd need to do it anyway before porting YOLOX parsing directly onto the Myriad X, since it's a lot easier to debug such parsing on the host first.
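Roughly, the host-side pattern boils down to the sketch below - a generic NeuralNetwork node plus an XLinkOut, with all parsing done on the host (the blob path and the 'output' layer name are placeholders; the efficientDet example above is the real reference):

```python
import numpy as np
import depthai as dai

pipeline = dai.Pipeline()

camRgb = pipeline.createColorCamera()
camRgb.setPreviewSize(416, 416)
camRgb.setInterleaved(False)

# Generic NeuralNetwork node - DepthAI does no decoding, it just runs the blob
nn = pipeline.createNeuralNetwork()
nn.setBlobPath("model.blob")  # placeholder path
camRgb.preview.link(nn.input)

xout = pipeline.createXLinkOut()
xout.setStreamName("nn")
nn.out.link(xout.input)

with dai.Device(pipeline) as device:
    q = device.getOutputQueue("nn", maxSize=4, blocking=False)
    while True:
        msg = q.get()
        # The raw FP16 tensor arrives unchanged; decode boxes/scores here on the host
        raw = np.array(msg.getLayerFp16("output"))  # layer name depends on the model
```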
Thoughts?
Thanks, Brandon
Thanks @Luxonis-Brandon. I used the example you shared and the YOLOX OpenVINO demo code to build a working solution. Please see the code at the end of this comment, and the attached zip containing a blob for the Myriad. There are some major drawbacks with this approach:
As for the performance impact, I get ~15 FPS running the attached code on my 8700K vs. ~28 FPS for YOLOv4 running via the YoloDetectionNetwork node. It's going to be completely unworkable on the ESP, so I haven't tried it.
On the positive side, it does seem to be more accurate than YOLOv4 in practice. For example, if I point the camera at my face, YOLOv4 says I'm a dog (I had headphones on at the time, which probably didn't help...), whereas YOLOX says I'm a person. YOLOX also detects the picture of my dog that I have sitting on my desk as a dog, whereas YOLOv4 oscillates between cat and dog. I did notice that the YOLOX bounding boxes were a bit janky for really large objects. Anyway, this very scientific test proves that you should look into it further :D
```python
from pathlib import Path
import numpy as np
import cv2
import depthai as dai
import time
def preproc(image, input_size, mean, std, swap=(2, 0, 1)):
    """Letterbox-resize the image to input_size, normalize, and convert to CHW FP16."""
    if len(image.shape) == 3:
        padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0
    else:
        padded_img = np.ones(input_size) * 114.0
    img = np.array(image)
    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    resized_img = cv2.resize(
        img,
        (int(img.shape[1] * r), int(img.shape[0] * r)),
        interpolation=cv2.INTER_LINEAR,
    ).astype(np.float32)
    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
    padded_img = padded_img[:, :, ::-1]
    padded_img /= 255.0
    if mean is not None:
        padded_img -= mean
    if std is not None:
        padded_img /= std
    padded_img = padded_img.transpose(swap)
    padded_img = np.ascontiguousarray(padded_img, dtype=np.float16)
    return padded_img, r
def nms(boxes, scores, nms_thr):
    """Single class NMS implemented in Numpy."""
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= nms_thr)[0]
        order = order[inds + 1]

    return keep
def multiclass_nms(boxes, scores, nms_thr, score_thr):
    """Multiclass NMS implemented in Numpy."""
    final_dets = []
    num_classes = scores.shape[1]
    for cls_ind in range(num_classes):
        cls_scores = scores[:, cls_ind]
        valid_score_mask = cls_scores > score_thr
        if valid_score_mask.sum() == 0:
            continue
        else:
            valid_scores = cls_scores[valid_score_mask]
            valid_boxes = boxes[valid_score_mask]
            keep = nms(valid_boxes, valid_scores, nms_thr)
            if len(keep) > 0:
                cls_inds = np.ones((len(keep), 1)) * cls_ind
                dets = np.concatenate(
                    [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
                )
                final_dets.append(dets)
    if len(final_dets) == 0:
        return None
    return np.concatenate(final_dets, 0)
def demo_postprocess(outputs, img_size, p6=False):
    """Map the raw YOLOX head outputs back to image coordinates (add grid offsets, scale by stride)."""
    grids = []
    expanded_strides = []

    if not p6:
        strides = [8, 16, 32]
    else:
        strides = [8, 16, 32, 64]

    hsizes = [img_size[0] // stride for stride in strides]
    wsizes = [img_size[1] // stride for stride in strides]

    for hsize, wsize, stride in zip(hsizes, wsizes, strides):
        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
        grids.append(grid)
        shape = grid.shape[:2]
        expanded_strides.append(np.full((*shape, 1), stride))

    grids = np.concatenate(grids, 1)
    expanded_strides = np.concatenate(expanded_strides, 1)
    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides

    return outputs
SHAPE = 416
labelMap = [
"person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train",
"truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
"suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
"baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
"fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
"orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
"chair", "sofa", "pottedplant", "bed", "diningtable", "toilet", "tvmonitor",
"laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
"teddy bear", "hair drier", "toothbrush"
]
p = dai.Pipeline()
p.setOpenVINOVersion(dai.OpenVINO.VERSION_2021_3)
class FPSHandler:
    def __init__(self, cap=None):
        self.timestamp = time.time()
        self.start = time.time()
        self.frame_cnt = 0

    def next_iter(self):
        self.timestamp = time.time()
        self.frame_cnt += 1

    def fps(self):
        return self.frame_cnt / (self.timestamp - self.start)
camRgb = p.createColorCamera()
camRgb.setPreviewSize(SHAPE, SHAPE)
camRgb.setResolution(dai.ColorCameraProperties.SensorResolution.THE_1080_P)
camRgb.setInterleaved(False)
camRgb.setColorOrder(dai.ColorCameraProperties.ColorOrder.BGR)
nn = p.createNeuralNetwork()
nn.setBlobPath(str(Path("yolox_tiny.blob").resolve().absolute()))
nn.setNumInferenceThreads(2)
nn.input.setBlocking(True)
# Send rgb frames to the host
rgb_xout = p.createXLinkOut()
rgb_xout.setStreamName("rgb")
camRgb.preview.link(rgb_xout.input)
# Send converted frames from the host to the NN
xinFrame = p.createXLinkIn()
xinFrame.setStreamName("inFrame")
xinFrame.out.link(nn.input)
# Send bounding boxes from the NN to the host via XLink
nn_xout = p.createXLinkOut()
nn_xout.setStreamName("nn")
nn.out.link(nn_xout.input)
# Pipeline is defined, now we can connect to the device
with dai.Device(p) as device:
    qRgb = device.getOutputQueue(name="rgb", maxSize=4, blocking=True)
    qIn = device.getInputQueue("inFrame", maxSize=4, blocking=True)
    qNn = device.getOutputQueue(name="nn", maxSize=4, blocking=True)
    fps = FPSHandler()

    while True:
        inRgb = qRgb.get()
        frame = inRgb.getCvFrame()

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        image, ratio = preproc(frame, (SHAPE, SHAPE), mean, std)
        # NOTE: The model expects an FP16 input image, but ImgFrame accepts a list of ints only. I work around this by
        # spreading the FP16 across two ints
        image = list(image.tobytes())

        dai_frame = dai.ImgFrame()
        dai_frame.setHeight(SHAPE)
        dai_frame.setWidth(SHAPE)
        dai_frame.setData(image)
        qIn.send(dai_frame)

        in_nn = qNn.tryGet()
        if in_nn is not None:
            fps.next_iter()
            cv2.putText(frame, "Fps: {:.2f}".format(fps.fps()), (2, SHAPE - 4), cv2.FONT_HERSHEY_TRIPLEX, 0.4, color=(255, 255, 255))

            data = np.array(in_nn.getLayerFp16('output')).reshape(1, 3549, 85)
            predictions = demo_postprocess(data, (SHAPE, SHAPE), p6=False)[0]

            boxes = predictions[:, :4]
            scores = predictions[:, 4, None] * predictions[:, 5:]

            boxes_xyxy = np.ones_like(boxes)
            boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
            boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
            boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
            boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
            dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)

            if dets is not None:
                final_boxes = dets[:, :4]
                final_scores, final_cls_inds = dets[:, 4], dets[:, 5]
                for i in range(len(final_boxes)):
                    bbox = final_boxes[i]
                    score = final_scores[i]
                    class_name = labelMap[int(final_cls_inds[i])]

                    if score >= 0.1:
                        # Limit the bounding box to 0..SHAPE
                        bbox[bbox > SHAPE] = 1
                        bbox[bbox < 0] = 0
                        xy_min = (int(bbox[0]), int(bbox[1]))
                        xy_max = (int(bbox[2]), int(bbox[3]))
                        # Display detection's BB, label and confidence on the frame
                        cv2.rectangle(frame, xy_min, xy_max, (255, 0, 0), 2)
                        cv2.putText(frame, class_name, (xy_min[0] + 10, xy_min[1] + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
                        cv2.putText(frame, f"{int(score * 100)}%", (xy_min[0] + 10, xy_min[1] + 40), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)

        cv2.imshow("rgb", frame)
        if cv2.waitKey(1) == ord('q'):
            break
```
Thanks @atmccarthy for this reference, this is great! Regarding normalization - we had a discussion about a different network that had the same issue (with nms on model) - you can take a look here for the details. Overall, I think this should be handled in the model itself, but for now, I'll check if I can make it work as-is, and add this code to our experiments repo, which can be later optimized.
YOLOX (OpenVINO IR and Myriad Inference Blob, ONNX, etc...)
https://github.com/PINTO0309/PINTO_model_zoo/tree/main/132_YOLOX
This is an old model that was converted within a day after YOLOX was officially released. Please ignore this if it is not helpful, as I have not followed the flow of this discussion.
I gave your snippet a try and I'm a bit confused about the results - for smaller objects, like a bottle or a person in an image, it works correctly, but for larger ones (like myself in front of the camera), the bounding box is incorrect.
I added the copied code to the gen2_yolox branch, but I think that the model should be modified first (with the --mean_values and --scale_values flags).
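If it helps, those flags can also be passed through from Python via the blobconverter package; a rough sketch (the exact keyword arguments, model path, and shave count are assumptions to adapt - the mean/scale values are just the ImageNet numbers from the script above rescaled to 0-255):

```python
import blobconverter

# Sketch: bake the normalization into the blob so the host no longer has to
# subtract the mean / divide by the std per frame. optimizer_params are
# forwarded to OpenVINO's Model Optimizer.
blob_path = blobconverter.from_onnx(
    model="yolox_tiny.onnx",  # placeholder path to the exported ONNX model
    data_type="FP16",
    shaves=6,
    optimizer_params=[
        "--mean_values=[123.675,116.28,103.53]",  # (0.485, 0.456, 0.406) * 255
        "--scale_values=[58.395,57.12,57.375]",   # (0.229, 0.224, 0.225) * 255
        # channel order (BGR vs RGB) may also need handling here
    ],
)
print(blob_path)
```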
Thanks very much for testing this out @VanDavv. The issue occurs when one of the bounding box coordinates is larger than the image, and it comes from this line:
bbox[bbox > SHAPE] = 1
It should have been:
bbox[bbox > SHAPE - 1] = SHAPE - 1
i.e. an out-of-range coordinate was being set to 1 instead of being clamped to 415!
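As a side note, the two clamping lines can be collapsed into a single np.clip call, e.g.:

```python
import numpy as np

SHAPE = 416
bbox = np.array([-3.2, 10.0, 450.7, 200.0])
# Clamp every coordinate into the valid 0 .. SHAPE-1 range in one call
bbox = np.clip(bbox, 0, SHAPE - 1)  # -> [  0.,  10., 415., 200.]
```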
I'll create a PR on your gen2_yolox branch.
It looks like I don't have permission to push branches to that repo. I've attached a patch to this comment. I made a few other changes:
Thanks @atmccarthy! I'll update the PR and circle back - your help is much appreciated!
@atmccarthy the patch works, thanks! I'll add a PR with your changes and then, thanks to your README.md, I'll try to compile this model with the preprocessing steps included.
@atmccarthy I tried to use the latest release of YOLOX - 0.1.1rc0 - since they changed the handling of mean and std values:

> Breaking changes: We remove the normalization operation like -mean/std. This will make the old weights incompatible.

I created a notebook that performs the ONNX -> blob conversion here and wanted to give it a try, but it's not working yet (I also created a gen2_yolox_no_preprocessing branch to reproduce).
The input size has increased with the latest version from 416x416 to 640x640. Also, for some reason, after converting the model to blob, the output is not [1, 3549, 85] but [1, 8400, 85] - although when checking in https://netron.app/, the output layer in ONNX is unchanged.
I won't have much time to dig into it in the following days, so will leave it here in case someone else would like to proceed further
> The input size has increased with the latest version from 416x416 to 640x640. Also, for some reason, after converting the model to blob, the output is not [1, 3549, 85] but [1, 8400, 85] - although when checking in https://netron.app/, the output layer in ONNX is unchanged
This depends on which YOLOX model you choose. The nano and tiny versions have a 416x416 input resolution and therefore output a [1, 3549, 85] tensor. All bigger models have a 640x640 input resolution and output a [1, 8400, 85] tensor. So the input size directly correlates with the second dimension of the output size. I assume this is because of their "anchor-free" approach. The third dimension contains all class probabilities (80) plus the bounding box and score information.
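A quick way to sanity-check those shapes is to count the grid cells per stride level (assuming the standard 8/16/32 strides):

```python
def num_predictions(input_size, strides=(8, 16, 32)):
    # One prediction per grid cell at each stride level (anchor-free head)
    return sum((input_size // s) ** 2 for s in strides)

print(num_predictions(416))  # 52*52 + 26*26 + 13*13 = 3549
print(num_predictions(640))  # 80*80 + 40*40 + 20*20 = 8400
```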
All ONNX models of YOLOX can be found in their repository.
> I won't have much time to dig into it in the following days, so will leave it here in case someone else would like to proceed further
I will have a look into YOLOX as soon as my OAK-D Lite device arrives.
@VanDavv @atmccarthy @Luxonis-Brandon Do you know whether it is possible to do the NMS thresholding in hardware, like for YoloV3? I could not find the code for the NMS thresholding, so I don't know the format that it expects...
Hey @JojoDevel ,
AFAIK, on-device NMS is only done for YoloV3, V4, and V5 (with the 2.14 release). We do not have direct support for NMS for all models right now. Since YOLOX is anchorless, you won't be able to mimic the exact outputs of YoloV3, V4, and V5. The Mobilenet node also expects the NN to output a 2D array (n rows x 7 columns), I think, so some part of this post-processing would have to be done on the host anyway. If there is more demand, we'll look into adding native support for YOLOX, just like we did with YoloV5 recently :)
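For illustration only - assuming the Mobilenet-style parser really does want rows of (image_id, label, confidence, xmin, ymin, xmax, ymax) normalized to 0..1, mapping the host-side NMS output from the script above into that layout would look roughly like this (to_ssd_layout is a hypothetical helper, not an existing API):

```python
import numpy as np

SHAPE = 416  # network input size used in the script above

def to_ssd_layout(dets, image_id=0):
    """dets: rows of (x1, y1, x2, y2, score, class) as returned by multiclass_nms()."""
    out = np.zeros((len(dets), 7), dtype=np.float32)
    out[:, 0] = image_id
    out[:, 1] = dets[:, 5]              # class index
    out[:, 2] = dets[:, 4]              # confidence
    out[:, 3:7] = dets[:, :4] / SHAPE   # normalized xmin, ymin, xmax, ymax
    return out
```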
Start with the *why*: The tiny variant of YOLOX has a number of advantages over yolov4-tiny:
There is also a nano variant that would be useful in low power applications, or in situations where you want to run multiple models.
Move to the *what*: It would be great if there were first-class support for YOLOX in the DepthAI Python and C++ APIs, e.g. by adding support for YOLOX to the existing YoloDetectionNetwork pipeline node. YOLOX is anchorless, so the existing device-side decoding probably won't work. I tried it out anyway (by following the OpenVINO instructions on the YOLOX GitHub page, and then compiling to a blob for the Myriad), but I get the following error:
[14442C10F1C14ED000] [50.022] [system] [critical] Fatal error. Please report to developers. Log: 'PlgDetectionParser' '109'