ultralytics / yolov5

YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite
https://docs.ultralytics.com
GNU Affero General Public License v3.0

Numerical differences after converting to CoreML #571

wmpauli closed this issue 4 years ago

wmpauli commented 4 years ago

How large a numerical difference between the PyTorch output and the CoreML output would be acceptable?

Additional context

Thank you for sharing your code. I have trained my own model and it works well. I converted it to CoreML and see some differences in behavior: in the CoreML version, I have to lower the confidence threshold to get the same results. I doubt that this will generally be a reliable solution, though.

I tried to debug this a bit. In PyTorch, if I set a breakpoint at line 38 of models/yolo.py, I get this:

> x[0][0,0,:5,:5,0]
tensor(
[[ 1.98129, -0.47193, -2.31064, -0.30100, -0.27282],
[ 1.47366,  0.05325, -1.33939,  0.01538, -0.07743],
[ 1.69490, -0.06065, -0.44285, -1.38465, -0.33619],
[ 1.97999, -0.12810, -0.97429, -1.57267, -0.31991],
[ 2.14888, -0.06261, -2.44762, -1.80719, -0.12404]])

But when running the CoreML model on the same image, I get:

> x[0][0,0,:5,:5,0]
tensor(
[[ 1.97070, -0.47119, -2.30078, -0.28369, -0.28760],
[ 1.47168,  0.05270, -1.33301,  0.02713, -0.07922],
[ 1.69141, -0.05957, -0.45947, -1.36816, -0.34399],
[ 1.98047, -0.12585, -1.01172, -1.55957, -0.31982],
[ 2.13867, -0.05994, -2.44922, -1.78809, -0.12598]])

Is this within what one would expect? (I'm new to CoreML.) Do I need to change the scale or bias during conversion? I played around with that a bit, but the settings in the export script seem to be the best.
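
On the scale/bias question: for image inputs, CoreML applies `scale * pixel + bias` per channel before the first layer, while YOLOv5's PyTorch pipeline simply divides uint8 RGB pixels by 255 with no mean subtraction. The settings that should reproduce the PyTorch input are therefore scale = 1/255 and bias = 0 for every channel, which I would expect the export script to use (an assumption; check your copy). The channel order of the image input (RGB vs BGR) is a related setting worth verifying. A NumPy sketch of both preprocessing paths; the helper names are hypothetical:

```python
import numpy as np

def coreml_style_preprocess(pixels_hwc, scale=1 / 255.0, bias=(0.0, 0.0, 0.0)):
    """Hypothetical helper mimicking CoreML image preprocessing: scale * pixel + per-channel bias."""
    x = pixels_hwc.astype(np.float32) * np.float32(scale) + np.asarray(bias, dtype=np.float32)
    return x.transpose(2, 0, 1)[None]  # HWC -> NCHW with a batch dimension

def yolov5_style_preprocess(pixels_hwc):
    """PyTorch-side preprocessing in YOLOv5: uint8 RGB -> float in [0, 1], NCHW."""
    x = pixels_hwc.astype(np.float32) / 255.0
    return x.transpose(2, 0, 1)[None]

rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)  # stand-in for a letterboxed image

matched = coreml_style_preprocess(pixels)              # scale = 1/255, bias = 0
unscaled = coreml_style_preprocess(pixels, scale=1.0)  # forgetting the scale
reference = yolov5_style_preprocess(pixels)
print("scale=1/255:", np.abs(matched - reference).max())
print("scale=1.0:  ", np.abs(unscaled - reference).max())
```

The two paths agree to within float32 rounding, whereas leaving the scale at 1.0 changes the network input by a factor of 255. A wrong scale or bias would therefore be expected to cause far larger discrepancies than the ones printed above.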

Is it generally expected that one has to play around with conf_thres and iou_thres after conversion?
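
To put a number on the difference, the sketch below copies the two printed 5x5 slices into NumPy and measures them. It only covers this slice (index 0 of the last dimension), so it says nothing about the rest of the output. Since the sigmoid's slope never exceeds 1/4, a logit error of d can move the corresponding sigmoid output by at most d/4; the decode step applies `.sigmoid()` to every raw output, so this bounds how far these particular values could shift after decoding.

```python
import numpy as np

# Values copied from the two printouts above: x[0][0, 0, :5, :5, 0] (raw outputs before sigmoid)
pytorch_out = np.array([
    [1.98129, -0.47193, -2.31064, -0.30100, -0.27282],
    [1.47366, 0.05325, -1.33939, 0.01538, -0.07743],
    [1.69490, -0.06065, -0.44285, -1.38465, -0.33619],
    [1.97999, -0.12810, -0.97429, -1.57267, -0.31991],
    [2.14888, -0.06261, -2.44762, -1.80719, -0.12404],
])
coreml_out = np.array([
    [1.97070, -0.47119, -2.30078, -0.28369, -0.28760],
    [1.47168, 0.05270, -1.33301, 0.02713, -0.07922],
    [1.69141, -0.05957, -0.45947, -1.36816, -0.34399],
    [1.98047, -0.12585, -1.01172, -1.55957, -0.31982],
    [2.13867, -0.05994, -2.44922, -1.78809, -0.12598],
])

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

abs_diff = np.abs(pytorch_out - coreml_out)
max_logit_diff = abs_diff.max()
mean_logit_diff = abs_diff.mean()

# sigmoid'(v) <= 1/4 everywhere, so |sigmoid(a) - sigmoid(b)| <= |a - b| / 4
max_sigmoid_diff = np.abs(sigmoid(pytorch_out) - sigmoid(coreml_out)).max()

print(f"max |logit diff|   = {max_logit_diff:.5f}")
print(f"mean |logit diff|  = {mean_logit_diff:.5f}")
print(f"max |sigmoid diff| = {max_sigmoid_diff:.5f} (bound: {max_logit_diff / 4:.5f})")
```

For this slice the largest absolute difference is about 0.037 and the mean about 0.008, so after the sigmoid these values move by less than 0.01. If the objectness and class outputs differ by similar amounts (they are not shown in this printout), only detections whose confidence lies very close to conf_thres would be affected.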

github-actions[bot] commented 4 years ago

Hello @wmpauli, thank you for your interest in our work! Please visit our Custom Training Tutorial to get started, and see our Jupyter Notebook, Docker Image, and Google Cloud Quickstart Guide for example environments.

If this is a bug report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we cannot help you.

If this is a custom model or data training question, please note that Ultralytics does not provide free personal support. As a leader in vision ML and AI, we do offer professional consulting, from simple expert advice up to delivery of fully customized, end-to-end production solutions for our clients.

For more information please visit https://www.ultralytics.com.

wmpauli commented 4 years ago

I guess this is not that unusual: https://developer.apple.com/forums/thread/82147

glenn-jocher commented 4 years ago

@wmpauli that seems a bit higher than I'd expect. Results will also depend on your quantization.

wmpauli commented 4 years ago

Thanks @glenn-jocher. I didn't do any quantization.
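
Even without explicit quantization, the CoreML printout is consistent with half-precision (FP16) computation: each CoreML value checked below survives a float16 round trip to within the 5 printed decimals, while the PyTorch values do not. This is an inference from the printed numbers only; the thread does not confirm which precision or compute unit the CoreML run used. A sketch of the check (`looks_like_fp16` is a hypothetical helper):

```python
import numpy as np

def looks_like_fp16(values, print_decimals=5):
    """True if every value is within print rounding of a float16 number (survives a float16 round trip)."""
    v = np.asarray(values, dtype=np.float64)
    roundtrip = v.astype(np.float16).astype(np.float64)
    tol = 0.5 * 10.0 ** (-print_decimals)  # the values were printed with 5 decimals
    return bool(np.all(np.abs(roundtrip - v) <= tol))

# Rows 1 and 4 of each printout in the original post
coreml_rows = [1.97070, -0.47119, -2.30078, -0.28369, -0.28760,
               1.98047, -0.12585, -1.01172, -1.55957, -0.31982]
pytorch_rows = [1.98129, -0.47193, -2.31064, -0.30100, -0.27282,
                1.97999, -0.12810, -0.97429, -1.57267, -0.31991]

print("CoreML values fp16-representable: ", looks_like_fp16(coreml_rows))
print("PyTorch values fp16-representable:", looks_like_fp16(pytorch_rows))
```

float16 carries roughly three significant decimal digits (the spacing between representable numbers is about 0.001 between 1 and 2, and about 0.002 between 2 and 4), so differences in the second or third decimal place, as in the printouts, are plausible for a deep network evaluated in FP16.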

fnuabhimanyu commented 4 years ago

@wmpauli I have been facing the same issue, but in my case the error is a lot higher, and I am not sure where it comes from. I want to double-check my inference code, so could you please post the code you use to run inference on the *.mlmodel file?

wmpauli commented 4 years ago

@Abhimanyu8713, below is the code I use for evaluation. Hopefully it is useful. You will probably have to change some of the constants at the top of the script. I'm still not sure why the results are so different after conversion. I suspect image normalization, or some other transform that happens either in CoreML or in PyTorch.

"""

Usage: export PYTHONPATH="$PWD" && python models/eval_coreml.py

"""

from models.yolo import Detect
from utils.utils import scale_coords, non_max_suppression, xyxy2xywh, plot_one_box
import coremltools
from PIL import Image
import torch
import numpy as np
import random
import cv2
import os
import shutil

# CONSTANTS
COREML_MODEL = "weights/best.mlmodel"
IMAGE_FOLDER = "inference/images/val/"
OUT_FOLDER = "inference/out/coreml/"
SAVE_IMG = True
VIEW_IMG = False
SAVE_TXT = False
CAT_NAMES = ['open', 'closed']
COLORS = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(CAT_NAMES))]
PATH = "./"
ANCHORS = ([116,90, 156,198, 373,326], [30,61, 62,45, 59,119], [10,13, 16,30, 33,23]) # from <model>.yml
IMG_SIZE = (640, 640)

# GLOBAL VARIABLES
nc = len(CAT_NAMES)
nl = len(ANCHORS)
na = len(ANCHORS[0]) // 2
no = nc + 5  # number of outputs per anchor
grid = [torch.zeros(1)] * nl  # init grid
a = torch.tensor(ANCHORS).float().view(nl, -1, 2)
anchor_grid = a.clone().view(nl, 1, -1, 1, 1, 2)
stride = [32, 16, 8] # check your model config
conf_thres = .3

def make_grid(nx=20, ny=20):
    yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
    return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

def resize_image(source_image):
    background = Image.new('RGB', IMG_SIZE, "black")
    source_image.thumbnail(IMG_SIZE)
    (w, h) = source_image.size
    background.paste(source_image, (int((IMG_SIZE[0] - w) / 2), int((IMG_SIZE[1] - h) / 2 )))

    return background

def eval(file_name):
    image = Image.open(os.path.join(IMAGE_FOLDER, file_name)).convert('RGB')  # ensure 3-channel RGB input

    image = resize_image(image)

    img = torch.zeros((1,3,IMG_SIZE[0],IMG_SIZE[1]))
    img[0, :, :, :] = torch.Tensor(np.array(image)).permute(2, 0, 1)
    im0 = np.array(image)

    predictions = model.predict({'images': image})

    z = []  # inference output
    x = []
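    for pred in predictions:
        x.append(torch.Tensor(predictions[pred]))
    # Reorder so that x[i] lines up with stride[i] and ANCHORS[i] (largest stride first).
    # The order of the CoreML output dict depends on the exported model: verify it for yours.
    x.reverse()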

    for i in range(nl):
        bs, _, ny, nx, _ = x[i].shape

        if grid[i].shape[2:4] != x[i].shape[2:4]:
            grid[i] = make_grid(nx, ny)

        y = x[i].sigmoid()
        y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid[i]) * stride[i]  # xy
        y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i]  # wh
        z.append(y.view(bs, -1, no))

    pred = torch.cat(z, 1)  # (bs, total anchors over all grids, no)

    pred = non_max_suppression(pred, conf_thres, .5, classes=None, agnostic=False)

    # Process detections
    for i, det in enumerate(pred):  # detections per image
        p, s = "./", ""

        if det is not None and len(det):
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh, needed for SAVE_TXT

            # Print results
            for c in det[:, -1].unique():
                n = (det[:, -1] == c).sum()  # detections per class
                s += '%g %ss, ' % (n, CAT_NAMES[int(c)])  # add to string

            # Write results
            for *xyxy, conf, cls in det:
                if SAVE_TXT:  # Write to file
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    with open(os.path.join(OUT_FOLDER, 'predictions.txt'), 'a') as f:
                        f.write(('%g ' * 5 + '\n') % (cls, *xywh))  # label format

                if SAVE_IMG or VIEW_IMG:  # Add bbox to image
                    label = '%s %.2f' % (CAT_NAMES[int(cls)], conf)
                    plot_one_box(xyxy, im0, label=label, color=COLORS[int(cls)], line_thickness=3)

    if SAVE_IMG:
        # im0 comes from PIL and is RGB; OpenCV expects BGR when writing
        cv2.imwrite(os.path.join(OUT_FOLDER, file_name), cv2.cvtColor(im0, cv2.COLOR_RGB2BGR))

def main():
    global model
    if os.path.exists(OUT_FOLDER):
        shutil.rmtree(OUT_FOLDER)
    os.makedirs(OUT_FOLDER)

    # Load the model
    model = coremltools.models.model.MLModel(COREML_MODEL)

    image_files = os.listdir(IMAGE_FOLDER)

    for i_f in image_files:
        eval(i_f)

if __name__ == "__main__":
    # execute only if run as a script
    main()
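
One preprocessing difference worth ruling out in the script above: `resize_image` pads with black, and PIL's `Image.thumbnail` only ever shrinks an image, so inputs smaller than 640x640 are never upscaled. YOLOv5's own `letterbox` helper in `utils/datasets.py` scales images up or down and, as far as I know, pads with gray (114, 114, 114); `detect.py` also pads to a minimal rectangle rather than a full square by default. If the PyTorch and CoreML runs are fed differently prepared images, their outputs will differ regardless of the conversion. A minimal PIL-only sketch of a square gray letterbox that could be used to feed both models the identical 640x640 image (`letterbox_pil` is a hypothetical name, not a YOLOv5 function):

```python
from PIL import Image

def letterbox_pil(img, size=(640, 640), fill=(114, 114, 114)):
    """Resize `img` to fit `size` (upscaling allowed), keeping aspect ratio, and pad the rest with `fill`."""
    img = img.convert("RGB")
    w, h = img.size
    r = min(size[0] / w, size[1] / h)  # scale ratio; > 1 means upscaling
    new_w, new_h = round(w * r), round(h * r)
    resized = img.resize((new_w, new_h), Image.BILINEAR)
    canvas = Image.new("RGB", size, fill)
    left, top = (size[0] - new_w) // 2, (size[1] - new_h) // 2
    canvas.paste(resized, (left, top))
    return canvas, r, (left, top)  # ratio and offset are needed to map boxes back to the original image

# Example: a 320x160 red image is upscaled 2x to 640x320 and centred vertically
demo, ratio, (pad_left, pad_top) = letterbox_pil(Image.new("RGB", (320, 160), (255, 0, 0)))
print(demo.size, ratio, pad_left, pad_top)  # (640, 640) 2.0 0 160
```

In `eval`, `image = resize_image(image)` could then become `image, ratio, pad = letterbox_pil(image)`, with the same image also passed to the PyTorch model when comparing outputs.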

dlawrences commented 4 years ago

Noticed this is the case for me as well. Have any of you found a solution?

I know there are compatibility issues between PyTorch and TensorFlow upsample-like operations that sometimes lead to differences (some info here), but as far as I can tell, YOLOv5 should be fine (it uses nearest-neighbor upsampling).

@glenn-jocher: did you benchmark your exported YOLOv5 model against the results you get from the trained checkpoints with the detect.py/test.py scripts? If so, did you see any big discrepancies?

glenn-jocher commented 4 years ago

@dlawrences we benchmarked the FP16/FP32 differences as negligible during the FP16 update; other than that, no, we do not test exported models with test.py. When quantizing in CoreML, you can clearly see progressively worse deterioration in the anecdotal results in iDetection at higher quantization levels, with both the kmeans and linear methods.
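
To make the effect of the quantization level concrete, here is a toy NumPy sketch of linear (uniform min/max) weight quantization. It is neither coremltools' implementation nor its k-means mode, and the weights are synthetic random numbers; it only illustrates why the error grows quickly as bits are removed: the step between levels, and with it the worst-case rounding error (half a step), roughly doubles for every bit dropped.

```python
import numpy as np

def linear_quantize(w, nbits):
    """Uniform min/max quantization to 2**nbits levels, then dequantize back to float."""
    w_min, w_max = float(w.min()), float(w.max())
    levels = 2 ** nbits - 1
    step = (w_max - w_min) / levels
    q = np.round((w - w_min) / step)  # integer codes in [0, levels]
    return q * step + w_min, step

rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.05, size=100_000).astype(np.float32)  # synthetic stand-in for conv weights

for nbits in (8, 6, 4, 2):
    w_hat, step = linear_quantize(weights, nbits)
    err = np.abs(w_hat - weights)
    print(f"{nbits}-bit: step={step:.5f}  max err={err.max():.5f}  mean err={err.mean():.6f}")
```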

dlawrences commented 4 years ago

@glenn-jocher thanks, Glenn. For clarity, you are still relying on the PyTorch > ONNX > CoreML conversion path, right?

Overall, I think it is probably related to this bit: https://github.com/apple/coremltools/issues/831

I will dig through the convert API tomorrow to see if there's anything we can force to avoid this.

Cheers

glenn-jocher commented 4 years ago

@dlawrences yes, but this is a coreml step, so may or may not depend on the route the model took to get there.

wmpauli commented 4 years ago

@dlawrences, it is my understanding that the conversion is actually PyTorch > traced PyTorch > CoreML, i.e. without ONNX. Also, I don't get the error message mentioned in https://github.com/apple/coremltools/issues/831.

fnuabhimanyu commented 4 years ago

Hi @wmpauli, thanks for sharing the code. My code is very similar to yours. The only thing I have not included is:

x.reverse()

I don't understand why we need to reverse the prediction array. Can you shed some light on this? Also, for the CoreML model, did you set model.model[-1].export = True in models/export.py?

dlawrences commented 4 years ago

Updates:

I now get much better results than before on device using YOLOv5s, just by upgrading to coremltools==4.0b2.

I have not yet benchmarked the same footage against the detect.py results; I will do so in the coming days.

glenn-jocher commented 4 years ago

@dlawrences I wonder if we should add the export dependencies (onnx, coremltools==4.0b2) to requirements.txt. I haven't so far because I suspect the vast majority of users don't need them. The way I handled this for pycocotools (for computing official COCO mAP) was to add it to requirements.txt but comment it out: https://github.com/ultralytics/yolov5/blob/66744a0df1935f94090bbaf632f4f070051f5502/requirements.txt#L1-L14

Other repos, like pytorch lightning have a requirements folder with different requirements.txt files added by use case, so that's another option (i.e. requirements/base.txt and requirements/export.txt).

dlawrences commented 4 years ago

@glenn-jocher I think you should use coremltools==4.0b2 in your requirements.txt going forward. Exporting a model with it is certainly faster and cleaner, and the resulting mlmodel performs better.

glenn-jocher commented 4 years ago

@dlawrences ok! I've updated requirements.txt now with different sections, with only the base section uncommented. For export I have this. https://github.com/ultralytics/yolov5/blob/7b2b52193d2d715230f8d9dfb660fb95fe3e3628/requirements.txt#L19-L22

I have a feeling I should separate this into its own requirements/export.txt file, to allow simple export-related pip installs, but I'd like to minimize adding directories and files as much as possible.

torch 1.6 is not compatible with coremltools 4.0b2, and onnx 1.7 has its own issue with unsupported hardswish layers. I've raised the hardswish issue at https://github.com/onnx/onnx/issues/2728#issuecomment-674179212.

However, v2.0 models export correctly via both ONNX and coremltools using torch 1.5.1, so I believe the best short-term workflow is to train LeakyReLU() models if they will need to be exported, and then to export them in a torch 1.5.1 environment.
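
The version constraint above can be encoded in a small pre-export check. This is a hypothetical helper that only restates what this comment reports (torch 1.6 not working with coremltools 4.0b2, torch 1.5.1 working); it does not query pip and says nothing about other combinations.

```python
def parse_major_minor(version):
    """'1.5.1' -> (1, 5); '1.6.0+cu101' -> (1, 6). Patch, local and pre-release parts are ignored."""
    major, minor = version.split("+")[0].split(".")[:2]
    minor_digits = "".join(ch for ch in minor if ch.isdigit())
    return int(major), int(minor_digits)

def export_env_warnings(torch_version, coremltools_version):
    """Hypothetical pre-export check encoding only the combination reported in this thread."""
    messages = []
    # Compare as integer tuples so that e.g. "1.10" is correctly treated as newer than "1.6"
    if coremltools_version.startswith("4.0b2") and parse_major_minor(torch_version) >= (1, 6):
        messages.append(
            f"torch {torch_version} was reported incompatible with coremltools 4.0b2; "
            "consider exporting from a torch 1.5.1 environment"
        )
    return messages

print(export_env_warnings("1.6.0", "4.0b2"))
print(export_env_warnings("1.5.1", "4.0b2"))
```

In an export script it would be called as `export_env_warnings(torch.__version__, coremltools.__version__)`.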

github-actions[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

OctaM commented 3 years ago

@wmpauli I used your code to evaluate a CoreML model, but the bounding boxes are placed wrong. I noticed that you use a function plot_one_box; can you also share that code so I can double-check it against mine?
[Screenshot: bus]

Edit: fixed it. The problem was the stride order: in my model the order had to be stride = [8, 32, 16]. Be sure to check the anchors as well.
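
The stride-order problem here (and the `x.reverse()` question earlier) can be avoided by deriving each output's stride from its grid size instead of relying on output order: for a 640x640 input, a 20x20 grid implies stride 32, 40x40 stride 16 and 80x80 stride 8. A minimal sketch, assuming the raw output layout (bs, na, ny, nx, no) used in the eval script above and the default anchors listed there; `match_outputs_to_strides` is a hypothetical helper, and the anchor table must match your own model config:

```python
import numpy as np

IMG_SIZE = 640
# stride -> anchors in pixels (default values from the eval script above; check your <model>.yaml)
ANCHORS_BY_STRIDE = {
    8: [10, 13, 16, 30, 33, 23],
    16: [30, 61, 62, 45, 59, 119],
    32: [116, 90, 156, 198, 373, 326],
}

def match_outputs_to_strides(outputs, img_size=IMG_SIZE):
    """Pair each raw output (bs, na, ny, nx, no) with its stride and anchors, whatever order it came in."""
    matched = []
    for out in outputs:
        ny = out.shape[2]
        stride = img_size // ny
        if stride * ny != img_size or stride not in ANCHORS_BY_STRIDE:
            raise ValueError(f"unexpected grid size {ny} for input size {img_size}")
        anchors = np.array(ANCHORS_BY_STRIDE[stride], dtype=np.float32).reshape(-1, 2)
        matched.append((out, stride, anchors))
    # Largest stride first, matching stride = [32, 16, 8] in the eval script
    return sorted(matched, key=lambda m: m[1], reverse=True)

# Example: outputs arriving in the order strides [8, 32, 16] (nc = 2, so no = 7)
fake = [np.zeros((1, 3, 80, 80, 7)), np.zeros((1, 3, 20, 20, 7)), np.zeros((1, 3, 40, 40, 7))]
print([stride for _, stride, _ in match_outputs_to_strides(fake)])  # [32, 16, 8]
```

In the eval script, the sorted tuples would replace `x.reverse()` together with the hard-coded `stride` list and the `anchor_grid` lookup.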

glenn-jocher commented 3 years ago

CoreML export preview:

[Screenshot: CoreML export preview, 2021-03-18]

OctaM commented 3 years ago

Hi @glenn-jocher. [Screenshot: my model's preview, 19-03-2021]

My model is 29 MB instead of 7.7 MB; I assume yours is the quantized version.

Also, my model type is Neural Network, while yours is Neural Network -> Non Maximum Suppression. Are there any additional steps needed when exporting or quantizing the network in order to add NMS to my .mlmodel?

Thanks
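
For context on the "Neural Network -> Non Maximum Suppression" type: that label suggests a pipeline model in which a non-maximum-suppression stage follows the network, so the .mlmodel returns filtered boxes instead of raw grids. Building such a pipeline with coremltools is beyond this sketch; what the NMS stage does is the greedy IoU filtering that `non_max_suppression` also performs in the eval script above, shown here for a single class in plain NumPy (a simplified illustration, not YOLOv5's or CoreML's implementation):

```python
import numpy as np

def nms(boxes, scores, iou_thres=0.5):
    """Greedy NMS. boxes: (N, 4) as x1, y1, x2, y2; returns indices of kept boxes, highest score first."""
    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]  # indices sorted by descending score
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        # Intersection of the top-scoring box with all remaining boxes
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[1:][iou <= iou_thres]  # drop boxes that overlap the kept one too much
    return keep

boxes = np.array([[10, 10, 110, 110], [12, 12, 112, 112], [200, 200, 260, 260]], dtype=float)
scores = np.array([0.9, 0.8, 0.7])
print(nms(boxes, scores))  # [0, 2]: box 1 overlaps box 0 with IoU of about 0.92
```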

pytholic commented 3 years ago

@OctaM Hi. Did you figure out that NMS part?

pytholic commented 3 years ago

Quoting @glenn-jocher: "CoreML export preview: [screenshot, 2021-03-18]"

@glenn-jocher Can we get more info on the NMS part of the exported model? I also opened a new issue, https://github.com/ultralytics/yolov5/issues/5157, since my exported model is not working well!