Megvii-BaseDetection / YOLOX

YOLOX is a high-performance anchor-free YOLO, exceeding yolov3~v5 with MegEngine, ONNX, TensorRT, ncnn, and OpenVINO supported. Documentation: https://yolox.readthedocs.io/
Apache License 2.0
9.43k stars 2.21k forks source link

export torchscript.py #1691

Open lbg030 opened 1 year ago

lbg030 commented 1 year ago

I'm trying to export a .pt checkpoint to a .jit (TorchScript) file. However, the current TorchScript export script contains no preprocessing or postprocessing code. Is there a way to export the preprocessing and postprocessing into the same .jit file, so that everything is in one file?

zawlin commented 1 year ago

You can try something like the script below. The resulting fused.ts.pt is all-in-one: you give it an input image of any size and get back detections in that image's coordinates, with no separate pre- or postprocessing needed. There are probably some improvements to be made. For example, decoding of the network output is done on the CPU to avoid TorchScript pinning device indices. Most importantly, this means NMS also runs on the CPU instead of the GPU, which may or may not slightly reduce performance.

import torch
import torchvision
from torch import nn
from yolox.exp import get_exp

# --- Step 1: trace the bare YOLOX network (no grid decoding, no NMS) ---
img_size = 1280
device = 'cuda:0'

ckpt_filename = './models/yolox_l.pth'
exp = get_exp(None, 'yolox-l')
model = exp.get_model()
ckpt = torch.load(ckpt_filename, map_location='cpu')
model.load_state_dict(ckpt['model'])
# nn.Module.to() and .eval() both return the module itself, so they chain.
model = model.to(device).eval()

# Make the head emit raw per-anchor outputs; NmsFused.decode_outputs below
# re-implements the decoding inside the scripted wrapper.
model.head.decode_in_inference = False
traced_model_filename = './models/yolox_l.ts.pt'
# The trace is taken at a fixed square size of img_size (exp.test_size would be
# the experiment's own default instead).
input_size = [1, 3] + [img_size] * 2
print('Exported model input size:', input_size)
dummy_input = torch.randn(*input_size).to(device)
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save(traced_model_filename)
print('Exported model was saved to:', traced_model_filename)

#===============
import numpy as np
import cv2

def get_sample(im_path='ws/uav0000086_00000_v', frame='0000002.jpg'):
    """Load one frame from disk as a float tensor ready for the fused model.

    Args:
        im_path: directory containing the frames.
        frame: file name of the frame inside ``im_path``. Defaults to the
            previously hard-coded ``'0000002.jpg'``.

    Returns:
        A ``(1, C, H, W)`` float32 tensor holding the raw, unnormalized pixel
        values exactly as returned by ``cv2.imread`` (default flags).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise surface
            as a confusing ``AttributeError`` on ``.transpose``.
    """
    path = f'{im_path}/{frame}'
    im = cv2.imread(path)
    if im is None:
        raise FileNotFoundError(f'cv2.imread could not read image: {path}')
    # HWC -> CHW, copied into a contiguous buffer, then a leading batch axis.
    return torch.from_numpy(np.ascontiguousarray(im.transpose(2, 0, 1))).unsqueeze(0).float()
#===============
from typing import List, Set, Dict, Tuple, Optional
import torchvision.transforms.functional as F

def scale_coords(img1_shape:List[int], coords:torch.Tensor, img0_shape:List[int]):
    """Undo a centered letterbox on xyxy boxes, in place.

    Args:
        img1_shape: (height, width) of the frame the boxes currently live in
            (the padded and resized network input).
        coords: tensor whose first four columns are x1, y1, x2, y2. Modified in
            place; any further columns are left untouched.
        img0_shape: (height, width) of the frame to map the boxes into.

    The inverse assumes a single uniform scale factor and padding split evenly
    on both sides. Results are clamped to ``img0_shape``. Returns None.
    """
    dst_h, dst_w = img1_shape[0], img1_shape[1]
    src_h, src_w = img0_shape[0], img0_shape[1]
    # scale = new size / old size, limited by the tighter dimension.
    scale = min(dst_h / src_h, dst_w / src_w)
    offset_x = (dst_w - src_w * scale) / 2
    offset_y = (dst_h - src_h * scale) / 2
    coords[:, 0:4:2] -= offset_x  # x1, x2
    coords[:, 1:4:2] -= offset_y  # y1, y2
    coords[:, :4] /= scale
    clip_coords(coords, img0_shape)

def clip_coords(boxes:torch.Tensor, img_shape:List[int]):
    """Clamp xyxy boxes in place to an image of size img_shape = (height, width)."""
    height, width = img_shape[0], img_shape[1]
    # x columns (0, 2) are bounded by the width, y columns (1, 3) by the height.
    for col, limit in [(0, width), (1, height), (2, width), (3, height)]:
        boxes[:, col].clamp_(0, limit)

class ResizePad(nn.Module):
    """Letterbox an NCHW batch onto a fixed ``(h, w)`` canvas.

    The side that is too short for the target aspect ratio is zero-padded,
    with half the deficit added on each side. The result is then resized to
    ``(h, w)`` with ``torchvision.transforms.functional.resize``.
    """

    def __init__(self, w=1280, h=1280):
        super().__init__()
        self.w = w
        self.h = h

    def forward(self, image:torch.Tensor )->torch.Tensor:
        src_h, src_w = image.shape[2:]
        target_ratio = self.w / self.h
        # Rows / columns missing for the source to reach the target aspect
        # ratio. The two have opposite signs, so at most one is positive.
        pad_rows = int(src_w / target_ratio - src_h)
        pad_cols = int(target_ratio * src_h - src_w)
        # torchvision padding order is (left, top, right, bottom).
        if pad_rows > 0 and pad_cols < 0:
            half = pad_rows // 2
            image = F.pad(image, (0, half, 0, half), 0, "constant")
        elif pad_rows < 0 and pad_cols > 0:
            half = pad_cols // 2
            image = F.pad(image, (half, 0, half, 0), 0, "constant")
        return F.resize(image, [self.h, self.w])

    def _apply(self, fn):
        # Plain delegation; nn.Module._apply already returns the module.
        return super()._apply(fn)

class NmsFused(nn.Module):
    """Scriptable end-to-end detector: letterbox -> network -> decode -> NMS.

    In this script ``backbone`` is the traced YOLOX model exported with
    ``head.decode_in_inference = False``, i.e. it returns raw per-anchor
    outputs. Its traced input size must match ``self.resizer`` (1280x1280 by
    default), because the anchor grids are derived from the resizer's size.
    """

    def __init__(self,backbone):
        super(NmsFused, self).__init__()
        self.resizer = ResizePad()
        self.backbone = backbone
        # The 80 COCO class names, indexed by predicted class id.
        self.classes:List[str] = ["person","bicycle","car","motorcycle","airplane","bus","train","truck","boat","traffic light","fire hydrant","stop sign","parking meter","bench","bird","cat","dog","horse","sheep","cow","elephant","bear","zebra","giraffe","backpack","umbrella","handbag","tie","suitcase","frisbee","skis","snowboard","sports ball","kite","baseball bat","baseball glove","skateboard","surfboard","tennis racket","bottle","wine glass","cup","fork","knife","spoon","bowl","banana","apple","sandwich","orange","broccoli","carrot","hot dog","pizza","donut","cake","chair","couch","potted plant","bed","dining table","toilet","tv","laptop","mouse","remote","keyboard","cell phone","microwave","oven","toaster","sink","refrigerator","book","clock","vase","scissors","teddy bear","hair drier","toothbrush"]
        # Output strides of the detection levels, in the order the head emits them.
        self.strides=[8, 16, 32]

    def decode_outputs(self, outputs, dtype):
        """Convert raw head outputs into absolute (cx, cy, w, h, ...) values.

        Args:
            outputs: raw backbone output, presumably (batch, num_anchors,
                5 + num_classes). Columns 0:2 are grid offsets and columns 2:4
                log sizes, both in units of the level's stride.
            dtype: unused; kept so existing callers keep working.

        Returns:
            ``outputs`` with columns 0:4 converted to input-pixel units.
        """
        grid_list = []
        stride_maps = []
        for stride in self.strides:
            # Each level's feature map is the network input divided by its
            # stride (160/80/40 for 1280x1280). This assumes resizer.h and
            # resizer.w are multiples of the largest stride.
            hsize = self.resizer.h // stride
            wsize = self.resizer.w // stride
            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grid_list.append(grid)
            stride_maps.append(torch.full((1, grid.shape[1], 1), stride))

        grids = torch.cat(grid_list, dim=1).float()
        strides = torch.cat(stride_maps, dim=1).float()
        return torch.cat([
            (outputs[..., 0:2] + grids) * strides,
            torch.exp(outputs[..., 2:4]) * strides,
            outputs[..., 4:]
        ], dim=-1)

    def forward(self, img: torch.Tensor,conf_thre:float=0.7,nms_thre:float=0.45,class_agnostic:bool=False)->List[Tuple[torch.Tensor,List[str]]]:
        """Detect objects in a batch of images of any size.

        Args:
            img: NCHW image batch (the pixel format ``backbone`` was trained on).
            conf_thre: keep anchors whose obj_conf * class_conf >= this value.
            nms_thre: IoU threshold passed to NMS.
            class_agnostic: if True, run NMS across all classes at once.

        Returns:
            Exactly one ``(detections, class_names)`` pair per input image, in
            batch order. ``detections`` is (K, 7):
            x1, y1, x2, y2, obj_conf, class_conf, class_id, with boxes in
            ``img`` pixel coordinates. K may be 0, in which case
            ``class_names`` is empty.
        """
        num_classes = len(self.classes)
        x = self.resizer(img)

        # Decoding and NMS run on the CPU so the scripted module does not pin
        # a specific device index.
        prediction = self.backbone(x).cpu()
        prediction = self.decode_outputs(prediction, x.type())
        # (cx, cy, w, h) -> (x1, y1, x2, y2); remaining columns unchanged.
        centers = prediction[:, :, 0:2]
        half_wh = prediction[:, :, 2:4] / 2
        prediction = torch.cat((centers - half_wh, centers + half_wh, prediction[:, :, 4:]), dim=2)

        output:List[Tuple[torch.Tensor,List[str]]] = []
        for image_pred in prediction:
            # Best class per anchor (scores and ids, each shaped (A, 1)).
            class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)
            # Index the class axis explicitly so the mask stays 1-D even when
            # there is a single anchor row.
            conf_mask = image_pred[:, 4] * class_conf[:, 0] >= conf_thre
            # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
            detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
            detections = detections[conf_mask]
            if detections.size(0) == 0:
                # Still emit an (empty) entry so output[i] stays aligned with
                # input image i instead of silently skipping it.
                no_names: List[str] = []
                output.append((detections, no_names))
                continue

            if class_agnostic:
                nms_out_index = torchvision.ops.nms(
                    detections[:, :4],
                    detections[:, 4] * detections[:, 5],
                    nms_thre,
                )
            else:
                nms_out_index = torchvision.ops.batched_nms(
                    detections[:, :4],
                    detections[:, 4] * detections[:, 5],
                    detections[:, 6],
                    nms_thre,
                )
            detections = detections[nms_out_index]

            # Map boxes from the letterboxed network frame back onto img (in place).
            scale_coords(x.shape[2:], detections[:, :4], img.shape[2:])
            cls_strings:List[str] = [self.classes[int(ii)] for ii in detections[:,6] ]
            output.append((detections,cls_strings))
        return output

    def _apply(self, fn):
        # Plain delegation to nn.Module._apply (which recurses into submodules).
        super(NmsFused, self)._apply(fn)
        return self
#===============
# --- Step 2: wrap the traced network with pre/post-processing and script it ---
inputs = get_sample()

# Reuse the path the trace was saved to above rather than repeating it.
traced_model = torch.jit.load(traced_model_filename, map_location=device)
to_trace = NmsFused(traced_model).to(device)
inputs = inputs.to(torch.device(device))

# Run the eager module once as a reference. These are inference-only passes,
# so skip building an autograd graph (a large allocation at 1280x1280).
with torch.no_grad():
    output_orig = to_trace(inputs)
script_module = torch.jit.script(to_trace)
script_module.save('models/fused.ts.pt')
with torch.no_grad():
    results = script_module(inputs)

import opencv_jupyter_ui as jcv2
from typing import List, Set, Dict, Tuple, Optional
import torchvision.transforms.functional as F

# --- Step 3: reload the fused module and draw its detections ---
model = torch.jit.load("models/fused.ts.pt",map_location=device)

with torch.no_grad():
    im_path = 'ws/uav0000086_00000_v'
    im = cv2.imread(f'{im_path}/0000010.jpg')
    if im is None:
        raise FileNotFoundError(f'cv2.imread could not read {im_path}/0000010.jpg')
    # Build the network input from the very frame that is drawn on. get_sample()
    # reads a different frame (0000002.jpg), which put boxes on the wrong image.
    # Inference completes before any drawing, so the in-place cv2 calls below
    # cannot affect the input.
    inputs = torch.from_numpy(np.ascontiguousarray(im.transpose(2, 0, 1))).unsqueeze(0).float()
    inputs = inputs.to(device)
    # Second positional argument is conf_thre.
    results = model(inputs, .01)

    font = cv2.FONT_HERSHEY_SIMPLEX
    fontScale = 0.5
    color = (0, 255, 255)
    thickness = 1
    # One (detections, class_names) pair per image; guard against an empty list.
    if len(results) > 0:
        detections, names = results[0]
        for i in range(detections.shape[0]):
            box = detections[i].int().cpu().numpy()
            cv2.rectangle(im, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)
            cv2.putText(im, names[i], (box[0], box[1]), font,
                        fontScale, color, thickness, cv2.LINE_AA)
    jcv2.imshow('im',im)
    jcv2.waitKey(1)