biubug6 / Pytorch_Retinaface

Retinaface gets 80.99% on widerface hard val using mobilenet0.25.

ONNX inference verification #132

Open · Single430 opened this issue 3 years ago

Single430 commented 3 years ago

@biubug6

Hello, I've been looking at your project recently and ran into some problems along the way; I've tested a lot of things but still haven't solved them.

After converting to ONNX, I verified the ONNX inference results and found some discrepancies.

[screenshot of the inaccurate detection results attached in the original issue]

As you can see, the boxes are inaccurate, and the landmarks likely have problems too. Since the ONNX input is a fixed 320x320 image, I suspect the results are not being scaled back to the original image correctly, but after many attempts I still haven't solved it.

The ONNX inference code is as follows:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings

import cv2
import onnx
import torch
import numpy as np
import onnxruntime
from layers.functions.prior_box import PriorBox

from data import cfg_mnet as cfg

warnings.filterwarnings("ignore")

class ONNXModel(object):
    def __init__(self, onnx_path):
        """
        :param onnx_path:
        """
        self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)
        print("input_name:{}".format(self.input_name))
        print("output_name:{}".format(self.output_name))

    def get_output_name(self, onnx_session):
        """Collect the names of all model outputs."""
        output_name = []
        for node in onnx_session.get_outputs():
            output_name.append(node.name)
        return output_name

    def get_input_name(self, onnx_session):
        """Collect the names of all model inputs."""
        input_name = []
        for node in onnx_session.get_inputs():
            input_name.append(node.name)
        return input_name

    def get_input_feed(self, input_name, image_tensor):
        """Build the {input_name: tensor} feed dict for session.run()."""
        input_feed = {}
        for name in input_name:
            input_feed[name] = image_tensor
        return input_feed

    def forward(self, image_tensor):
        """Run the session and return (loc, conf, landms)."""
        # The input dtype must match the model (float32 here); any of the
        # three calling styles below works:
        # loc, conf, landms = self.onnx_session.run(None, {self.input_name[0]: image_tensor})
        # loc, conf, landms = self.onnx_session.run(self.output_name, input_feed={self.input_name[0]: image_tensor})
        input_feed = self.get_input_feed(self.input_name, image_tensor)
        loc, conf, landms = self.onnx_session.run(self.output_name, input_feed=input_feed)
        return loc, conf, landms

def get_input_img(path="./img/sample.jpg"):
    image_path = path
    img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)
    _img = np.float32(img_raw)

    # testing scale
    x_size = 320
    y_size = 320
    im_shape = _img.shape

    resize = x_size / max(im_shape)
    resize_x = round(float(x_size) / float(im_shape[1]), 2)
    resize_y = round(float(y_size) / float(im_shape[0]), 3)
    # prevent bigger axis from being more than max_size:
    # scale = x_size / max(im_shape)
    # print(scale)

    print(resize, im_shape)
    # _img = cv2.resize(_img, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
    # img = np.zeros((x_size, y_size, im_shape[2]), dtype=np.float32)
    # img[:_img.shape[0], :_img.shape[1]] = _img

    img = cv2.resize(_img, None, None, fx=resize_x, fy=resize_y, interpolation=cv2.INTER_LINEAR)

    img -= (104, 117, 123)
    img = img.transpose(2, 0, 1)
    img = img[np.newaxis, :]

    return img_raw, img, resize

def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat((
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes

def decode_landm(pre, priors, variances):
    """Decode landm from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        pre (tensor): landm predictions for loc layers,
            Shape: [num_priors,10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions
    """
    landms = torch.cat((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
                        ), dim=1)
    return landms

def py_cpu_nms(dets, thresh):
    """Pure Python NMS baseline."""
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep

worker = ONNXModel("./faceDetector.onnx")

img_raw, img, resize = get_input_img("./img/sample.jpg")

loc, conf, landms = worker.forward(img)
loc = torch.Tensor(loc)
conf = torch.Tensor(conf)
landms = torch.Tensor(landms)
# torch.Size([1, 59500, 4]) torch.Size([1, 59500, 2]) torch.Size([1, 59500, 10])
print(loc.shape, conf.shape, landms.shape)

scale = torch.Tensor([img_raw.shape[1], img_raw.shape[0], img_raw.shape[1], img_raw.shape[0]])
scale1 = torch.Tensor([img_raw.shape[1], img_raw.shape[0], img_raw.shape[1], img_raw.shape[0],
                       img_raw.shape[1], img_raw.shape[0], img_raw.shape[1], img_raw.shape[0],
                       img_raw.shape[1], img_raw.shape[0]])
# scale = torch.Tensor([320]*4)
# scale1 = torch.Tensor([320]*10)

# face-class scores
scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
# build priors for the 320x320 input and decode boxes
priorbox = PriorBox(cfg, image_size=(320, 320))
priors = priorbox.forward()
prior_data = priors.data

boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
landms = decode_landm(landms.data.squeeze(0), prior_data, cfg['variance'])

boxes = (boxes * scale) / 1  # note: detect.py uses "/ resize" here
boxes = boxes.cpu().numpy()

landms = landms * scale1 / 1  # note: detect.py uses "/ resize" here
landms = landms.cpu().numpy()

# ignore low scores
inds = np.where(scores > 0.6)[0]
boxes = boxes[inds]
landms = landms[inds]
scores = scores[inds]

# keep top-K before NMS
order = scores.argsort()[::-1][:5000]
boxes = boxes[order]
landms = landms[order]
scores = scores[order]

# do NMS
dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
keep = py_cpu_nms(dets, 0.4)
# keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
dets = dets[keep, :]
landms = landms[keep]

# keep top-K faster NMS
dets = dets[:750, :]
landms = landms[:750, :]

dets = np.concatenate((dets, landms), axis=1)

print(dets.shape)
count = 1
for b in dets:
    # if b[4] < 0.6:
    #     continue
    text = f"{count}-{b[4]:.3f}"
    b = list(map(int, b))

    print(count, b)
    count += 1
    cv2.rectangle(img_raw, (int(b[0]), int(b[1])), (int(b[2]), int(b[3])), (0, 0, 255), 2)
    cx = b[0]
    cy = b[1] + 12
    cv2.putText(img_raw, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))

    # landms
    cv2.circle(img_raw, (b[5], b[6]), 1, (0, 0, 255), 4)
    cv2.circle(img_raw, (b[7], b[8]), 1, (0, 255, 255), 4)
    cv2.circle(img_raw, (b[9], b[10]), 1, (255, 0, 255), 4)
    cv2.circle(img_raw, (b[11], b[12]), 1, (0, 255, 0), 4)
    cv2.circle(img_raw, (b[13], b[14]), 1, (255, 0, 0), 4)
# save image
name = "test.jpg"
cv2.imwrite(name, img_raw)
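
Incidentally, a quick way to tell whether the error comes from the ONNX export itself or from the pre/post-processing is to run the exact same tensor through both models and compare the raw outputs. A minimal sketch, assuming `net` is the PyTorch RetinaFace model loaded in eval mode as in detect.py:

x = np.random.randn(1, 3, 320, 320).astype(np.float32)  # dummy input

# raw ONNX outputs
onnx_loc, onnx_conf, onnx_landms = worker.onnx_session.run(
    None, {worker.input_name[0]: x})

# raw PyTorch outputs on the same tensor (`net` assumed loaded as in detect.py)
with torch.no_grad():
    torch_loc, torch_conf, torch_landms = net(torch.from_numpy(x))

np.testing.assert_allclose(torch_loc.numpy(), onnx_loc, rtol=1e-3, atol=1e-5)
np.testing.assert_allclose(torch_conf.numpy(), onnx_conf, rtol=1e-3, atol=1e-5)

If these assertions pass, the export is fine and the divergence has to come from the resize or the coordinate scale-back.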
zhouyongxyz commented 3 years ago

I ran into the same problem: after converting to ONNX with opset 11 and then to TensorRT, the inference results have inaccurate face bounding boxes and wrong landmarks. What could be going on here?

imistyrain commented 3 years ago

I suspect the scaling is where it goes wrong:

resize_x = round(float(x_size) / float(im_shape[1]), 2)
resize_y = round(float(y_size) / float(im_shape[0]), 3)

Change it to:

resize_x = round(float(x_size) / float(im_shape[1]), 6)
resize_y = round(float(y_size) / float(im_shape[0]), 6)
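
Another option that sidesteps the rounding entirely is to pass a target size to cv2.resize, so the resized image is guaranteed to match the 320x320 PriorBox grid; a minimal sketch (not the repo's code):

h, w = img_raw.shape[:2]
# resize straight to the network's fixed input size, no fx/fy rounding
img = cv2.resize(np.float32(img_raw), (320, 320), interpolation=cv2.INTER_LINEAR)
# decoded boxes are normalized to that grid, so scaling by the raw image
# size maps them back exactly
scale = torch.Tensor([w, h, w, h])
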
QuantumLiu commented 3 years ago

The inference sample code detect.py is not clear at all. It seems like it does not resize the input? It's hard to write my own inference code.

QuantumLiu commented 3 years ago

> The inference sample code detect.py is not clear at all. It seems like it does not resize the input? It's hard to write my own inference code.

It just keeps converting torch.Tensor to np.array and back to torch.Tensor again and again.
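
For what it's worth, the decoding does not need torch at all; a numpy-only sketch of decode(), mirroring the torch version above, keeps the whole post-processing in a single array library:

def decode_np(loc, priors, variances):
    """Decode loc regressions against center-form priors (numpy only)."""
    boxes = np.concatenate((
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])), axis=1)
    boxes[:, :2] -= boxes[:, 2:] / 2  # (cx, cy) -> (x1, y1)
    boxes[:, 2:] += boxes[:, :2]      # (w, h)   -> (x2, y2)
    return boxes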

lzhx171 commented 3 years ago

The PyTorch version scales width and height by the same ratio. Could the inaccuracy come from resizing everything to 320*320, which distorts the faces?
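
The commented-out lines in get_input_img above were attempting exactly that. A minimal sketch of the aspect-preserving variant (letterbox_320 is a hypothetical helper, not from the repo):

def letterbox_320(img_raw, size=320):
    """Resize by one uniform factor and zero-pad to size x size."""
    h, w = img_raw.shape[:2]
    r = size / max(h, w)  # single scale factor, no distortion
    resized = cv2.resize(np.float32(img_raw), (round(w * r), round(h * r)),
                         interpolation=cv2.INTER_LINEAR)
    canvas = np.zeros((size, size, 3), dtype=np.float32)
    canvas[:resized.shape[0], :resized.shape[1]] = resized
    return canvas, r

# After decoding, boxes are normalized to the padded canvas, so
# boxes * size / r gives coordinates in the original image.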