Atten4Vis / LW-DETR

This repository is an official implementation of the paper "LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection".

Questions about ONNX and PyTorch Inference #16

Closed · MinGiSa closed this 3 months ago

MinGiSa commented 3 months ago

I used your code to convert the model to ONNX and wrote the inference code myself. However, when I measured the runtime, it takes more than 200 ms per image. If possible, could you provide PyTorch or ONNX inference code? My code is as follows.

======================================

import os
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
import torchvision
import argparse
import numpy as np
from PIL import Image
import cv2
import onnxruntime as nxrun
import torch
import torchvision.transforms as T
import tqdm
import time

def parser_args():
    parser = argparse.ArgumentParser('Object detection using ONNX model')
    parser.add_argument('--path', type=str, required=True, help='ONNX model file path')
    parser.add_argument('--image_dir', type=str, required=True, help='Directory containing images to run inference on')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save output images with detections')
    parser.add_argument('--threshold', type=float, default=0.5, help='Score threshold for displaying bounding boxes')
    parser.add_argument('--iou_threshold', type=float, default=0.5, help='IoU threshold for non-max suppression')
    parser.add_argument('--class_names', type=str, required=True, help='Path to class names file')
    return parser.parse_args()

def findClassNameYOLO(annotationPath):
    with open(annotationPath, 'r') as file:
        className = file.read().splitlines()
    return className

def load_image(file_path):
    return Image.open(file_path).convert("RGB")

def infer_transforms():
    # Resize to the 640x640 network input and apply ImageNet mean/std normalization.
    normalize = T.Compose([
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return T.Compose([
        T.Resize((640, 640)),
        normalize,
    ])

def box_cxcywh_to_xyxy(x):
    # Convert (center_x, center_y, width, height) boxes to (x1, y1, x2, y2).
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w.clamp(min=0.0)),
         (y_c - 0.5 * h.clamp(min=0.0)),
         (x_c + 0.5 * w.clamp(min=0.0)),
         (y_c + 0.5 * h.clamp(min=0.0))]
    return torch.stack(b, dim=-1)
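# Worked example (added for illustration, not part of the original script):
# a normalized box (cx, cy, w, h) = (0.5, 0.5, 0.2, 0.4) maps to
# (x1, y1, x2, y2) = (0.4, 0.3, 0.6, 0.7).
assert torch.allclose(box_cxcywh_to_xyxy(torch.tensor([0.5, 0.5, 0.2, 0.4])),
                      torch.tensor([0.4, 0.3, 0.6, 0.7]))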

def soft_nms(boxes, scores, iou_threshold=0.5, sigma=0.5, score_threshold=0.001):
    # Gaussian soft-NMS: decay the scores of overlapping boxes instead of dropping them.
    N = boxes.shape[0]
    indexes = torch.arange(0, N, dtype=torch.float).view(N, 1)
    dets = torch.cat((boxes, scores.view(N, 1), indexes), dim=1)
    keep = []

    while dets.shape[0]:
        max_idx = torch.argmax(dets[:, 4])
        max_box = dets[max_idx, :4]
        max_score = dets[max_idx, 4]
        keep.append(dets[max_idx, 5].item())

        dets = torch.cat((dets[:max_idx], dets[max_idx+1:]), dim=0)
        if not dets.shape[0]:
            break

        ious = torchvision.ops.box_iou(max_box.unsqueeze(0), dets[:, :4]).squeeze()
        weights = torch.exp(-(ious ** 2) / sigma)
        dets[:, 4] *= weights
        dets = dets[dets[:, 4] > score_threshold]

    return torch.tensor(keep, dtype=torch.long)
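# Usage sketch (added for illustration, not part of the original script): soft_nms
# is a drop-in alternative to the hard torchvision.ops.nms call in post_process
# below; it decays the scores of overlapping boxes instead of discarding them.
#
#   keep = soft_nms(result['boxes'], result['scores'], sigma=0.5)
#   result['boxes'], result['scores'] = result['boxes'][keep], result['scores'][keep]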

def generateColors(numClass):
    # Spread hues by the golden ratio conjugate so neighboring class colors stay distinct.
    colors = []
    golden_ratio_conjugate = 0.618033988749895
    hue = 0

    for i in range(numClass):
        hue += golden_ratio_conjugate
        hue %= 1

        rgb = hsv2rgb(hue, 0.9, 0.95)
        colors.append((int(rgb[2] * 255), int(rgb[1] * 255), int(rgb[0] * 255)))  # BGR format

    return colors

def hsv2rgb(h, s, v):
    if s == 0.0:
        return (v, v, v)

    i = int(h * 6.)
    f = (h * 6.) - i
    p, q, t = v * (1. - s), v * (1. - s * f), v * (1. - s * (1. - f))
    i %= 6

    if i == 0:
        return (v, t, p)
    if i == 1:
        return (q, v, p)
    if i == 2:
        return (p, v, t)
    if i == 3:
        return (p, q, v)
    if i == 4:
        return (t, p, v)
    if i == 5:
        return (v, p, q)

def post_process(outputs, target_sizes, iou_threshold, confidence_threshold):
    # 'labels' carries the raw classification logits and 'dets' the predicted boxes.
    out_logits, out_bbox = outputs['labels'], outputs['dets']

    # DETR-style decoding: take the top 300 (query, class) pairs over the flattened logits.
    prob = out_logits.sigmoid()
    topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1)
    scores = topk_values
    topk_boxes = topk_indexes // out_logits.shape[2]
    labels = topk_indexes % out_logits.shape[2]
    boxes = box_cxcywh_to_xyxy(out_bbox)
    boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

    # Scale normalized boxes to the original image size.
    img_h, img_w = target_sizes.unbind(1)
    scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
    boxes = boxes * scale_fct[:, None, :]

    results = []
    for s, l, b in zip(scores, labels, boxes):
        keep = s > confidence_threshold
        results.append({
            'scores': s[keep],
            'labels': l[keep],
            'boxes': b[keep]
        })

    # Apply NMS
    for result in results:
        keep = torchvision.ops.nms(result['boxes'], result['scores'], iou_threshold)
        result['boxes'] = result['boxes'][keep]
        result['scores'] = result['scores'][keep]
        result['labels'] = result['labels'][keep]

    return results
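# Shape note (added for illustration): for a DETR-style head, out_logits is
# (batch, num_queries, num_classes) and out_bbox is (batch, num_queries, 4) in
# normalized cxcywh format, so each flattened top-k index decodes as
# query = index // num_classes and class = index % num_classes.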

def saveImage(image, predictions, className, destPath, fileName, colors, original_size, resized_size):
    num_detections, detected_boxes, detected_scores, detected_labels = predictions

    # Calculate scaling factors
    scale_x = original_size[1] / resized_size[1]
    scale_y = original_size[0] / resized_size[0]

    for i in range(num_detections):
        box = detected_boxes[i]
        score = detected_scores[i]
        label = int(detected_labels[i])

        # Adjust box coordinates to original image size
        start_point = (int(box[0]), int(box[1]))
        end_point = (int(box[2]), int(box[3]))
        color = colors[label]
        thickness = 2

        image = cv2.rectangle(image, start_point, end_point, color, thickness)

        label_text = f"{className[label]}: {score:.2f}"
        (label_width, label_height), baseline = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        top_left = (start_point[0], start_point[1] - label_height - baseline)
        bottom_right = (start_point[0] + label_width, start_point[1])

        image = cv2.rectangle(image, top_left, bottom_right, color, cv2.FILLED)
        image = cv2.putText(image, label_text, (start_point[0], start_point[1] - baseline),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

    savePath = os.path.join(destPath, fileName)
    cv2.imwrite(savePath, image)

def infer_onnx(sess, image_dir, output_dir, threshold, iou_threshold, class_names):
    os.makedirs(output_dir, exist_ok=True)
    log_file_path = os.path.join(output_dir, "inference_log.txt")

    with open(log_file_path, 'w') as log_file:
        image_paths = [os.path.join(image_dir, img_name) for img_name in os.listdir(image_dir) if img_name.endswith('.bmp')]
        colors = generateColors(len(class_names))

        total_time = 0
        min_time = float('inf')
        max_time = 0
        num_images = len(image_paths)

        for idx, img_path in enumerate(tqdm.tqdm(image_paths)):
            image = load_image(img_path)
            width, height = image.size
            orig_target_sizes = torch.Tensor([height, width])
            image_tensor = infer_transforms()(image)

            # Time only the ONNX session run; pre/post-processing is excluded.
            samples = image_tensor[None].numpy()
            start_time = time.time()
            res = sess.run(None, {"input": samples})
            end_time = time.time()

            outputs = {'labels': torch.Tensor(res[1]), 'dets': torch.Tensor(res[0])}
            orig_target_sizes = torch.stack([orig_target_sizes], dim=0)
            results = post_process(outputs, orig_target_sizes, iou_threshold, threshold)

            process_time = (end_time - start_time) * 1000

            image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            saveImage(image_cv, (len(results[0]['scores']), results[0]['boxes'], results[0]['scores'], results[0]['labels']),
                      class_names, output_dir, os.path.basename(img_path), colors, (height, width), (height, width))

            total_time += process_time
            min_time = min(min_time, process_time)
            max_time = max(max_time, process_time)

            log_file.write(f"{idx:04d} // Process Time: {process_time:.4f} ms // {img_path}\n")

        avg_time = total_time / num_images
        log_file.write(f"\nMin Inference Time: {min_time:.2f} ms\n")
        log_file.write(f"Max Inference Time: {max_time:.2f} ms\n")
        log_file.write(f"Avg Inference Time: {avg_time:.2f} ms\n")

def main():
    args = parser_args()
    class_names = findClassNameYOLO(args.class_names)
    sess = nxrun.InferenceSession(args.path)
    infer_onnx(sess, args.image_dir, args.output_dir, args.threshold, args.iou_threshold, class_names)

if __name__ == '__main__':
    main()

=================================================

Min Inference Time: 217.61 ms
Max Inference Time: 295.26 ms
Avg Inference Time: 236.82 ms

xbsu commented 3 months ago

Please refer to deploy/benchmark.py for the onnx inference code.

What hardware are you running on? If CPU, this result is reasonable.
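One caveat when timing: the first few `sess.run` calls absorb lazy CUDA module loading and provider initialization, so benchmark loops usually discard some warm-up iterations before measuring. A minimal sketch of that pattern (not taken from deploy/benchmark.py; the model path and the "input" tensor name are placeholders matching the script above):

```python
import time
import numpy as np
import onnxruntime as nxrun

sess = nxrun.InferenceSession("model.onnx",
                              providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
dummy = np.random.rand(1, 3, 640, 640).astype(np.float32)

# Warm-up runs: lazy module loading and kernel autotuning happen here.
for _ in range(10):
    sess.run(None, {"input": dummy})

# Timed runs.
times = []
for _ in range(100):
    start = time.perf_counter()
    sess.run(None, {"input": dummy})
    times.append((time.perf_counter() - start) * 1000)
print(f"avg {np.mean(times):.2f} ms / min {np.min(times):.2f} ms")
```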

MinGiSa commented 3 months ago

> Please refer to deploy/benchmark.py for the onnx inference code.
>
> What hardware are you running on? If CPU, this result is reasonable.

Thank you. I forgot to use the CUDA execution provider, and I should have installed onnxruntime-gpu instead of onnxruntime.
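For anyone else who hits this: with the CPU-only onnxruntime package installed, `CUDAExecutionProvider` is unavailable and the session silently runs on CPU. A minimal sketch of the fix (the model path is a placeholder):

```python
# pip uninstall onnxruntime && pip install onnxruntime-gpu
import onnxruntime as nxrun

sess = nxrun.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],  # try GPU first, fall back to CPU
)
print(sess.get_providers())  # should list CUDAExecutionProvider with onnxruntime-gpu installed
```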

swrdZWJ commented 2 months ago

Hi, could you please share the ONNX model you converted? Thanks very much~