Tianxiaomo / pytorch-YOLOv4

PyTorch, ONNX and TensorRT implementation of YOLOv4
Apache License 2.0

Batch Inference doesn't work if TensorRT conversion is done in TensorRT Python API #494

Open HtutLynn opened 2 years ago

First of all, thanks for the great work! I use AlexeyAB's darknet fork to train custom YOLOv4 models. For TensorRT conversion, I use this repo to convert the darknet weights to ONNX, and then convert the ONNX model to a TensorRT engine with trtexec.

Converting batched YOLOv4 ONNX models with trtexec works fine. When I convert the same batched ONNX models with the TensorRT Python API, the conversion itself succeeds, but inference with the resulting batched engine does not give the expected results: the model outputs detections only for the first frame in the batch, and for the second frame the inference results are zero. I am not sure why this happens. Since conversion with trtexec works, something seems to go wrong in the Python API conversion. I need the Python API because I want to quantize the model to INT8.
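To narrow down the symptom, it can help to check each frame's slice of the raw engine output before any post-processing. The sketch below assumes the output has been copied back into a flat, batch-major host buffer of length `batch_size * per_frame_size`; `zero_frames` is a hypothetical helper, not part of this repo or of TensorRT.

```python
from typing import List, Sequence


def zero_frames(flat_output: Sequence[float], batch_size: int) -> List[int]:
    """Return indices of batch entries whose output slice is entirely zero.

    Assumes flat_output is batch-major: frame i occupies
    flat_output[i * per_frame : (i + 1) * per_frame].
    """
    if batch_size <= 0 or len(flat_output) % batch_size != 0:
        raise ValueError("output length is not a multiple of batch_size")
    per_frame = len(flat_output) // batch_size
    return [
        i
        for i in range(batch_size)
        if all(v == 0 for v in flat_output[i * per_frame:(i + 1) * per_frame])
    ]


# Two frames, three values each: the second frame came back all zeros.
print(zero_frames([0.9, 0.1, 0.3, 0.0, 0.0, 0.0], batch_size=2))  # [1]
```

If real images produce a non-empty result here, the zeros are already present in the raw engine output, so the post-processing code is not what introduces them.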

Sample TensorRT conversion script using the Python API:

import traceback

import tensorrt as trt

# Compare the major version as an integer: comparing trt.__version__[0]
# with '7'/'8' as strings breaks once the major version has two digits.
TRT_MAJOR = int(trt.__version__.split('.')[0])

# The ONNX parser in TensorRT 7+ requires an explicit-batch network.
EXPLICIT_BATCH = []
if TRT_MAJOR >= 7:
    EXPLICIT_BATCH.append(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))


def build_engine(onnx_file_path, engine_file_path, verbose=False, batch_size=1):
    """Takes an ONNX file and creates a TensorRT engine."""
    TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger()
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(*EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:

        if TRT_MAJOR >= 8:
            config = builder.create_builder_config()
            config.max_workspace_size = 1 << 28
            # max_batch_size only applies to implicit-batch networks; it has
            # no effect on the explicit-batch network created above.
            builder.max_batch_size = batch_size
            config.set_flag(trt.BuilderFlag.FP16)
            # config.set_flag(trt.BuilderFlag.STRICT_TYPES)
        else:
            builder.max_workspace_size = 1 << 28
            builder.max_batch_size = batch_size
            builder.fp16_mode = True
            # builder.strict_type_constraints = True

        # Parse model file
        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            if not parser.parse(model.read()):
                print('ERROR: Failed to parse the ONNX file.')
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None
        if TRT_MAJOR >= 7:
            # Overwrite the batch dimension of the first input with batch_size.
            # With the default batch_size=1 this replaces whatever batch size
            # the ONNX file was exported with.
            shape = list(network.get_input(0).shape)
            print('ONNX input shape: {}'.format(shape))
            shape[0] = batch_size
            network.get_input(0).shape = shape
        print('Completed parsing of ONNX file')

        print('Building an engine; this may take a while...')
        if TRT_MAJOR >= 8:
            engine = builder.build_engine(network, config)
        else:
            engine = builder.build_cuda_engine(network)
        if engine is None:
            print('ERROR: Failed to build the engine.')
            return None
        print('Completed creating engine')
        try:
            with open(engine_file_path, 'wb') as f:
                f.write(engine.serialize())
            return engine
        except Exception:
            traceback.print_exc()
            return None
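The script overwrites the first input's batch dimension with `batch_size`, which defaults to 1, regardless of the batch size baked into the ONNX file. Overriding the input also does not update any shape constants inside the graph (for example, Reshape targets, if the exporter folded the batch size into them). The sketch below is a stricter variant of that override which refuses to silently change a static batch dimension. `override_batch_dim` is a hypothetical helper, and it assumes a dynamic dimension is reported as -1, as TensorRT does for network input shapes. The shapes in the example are illustrative.

```python
from typing import List, Sequence


def override_batch_dim(onnx_shape: Sequence[int], batch_size: int) -> List[int]:
    """Return onnx_shape with its first (batch) dimension set to batch_size.

    Raises ValueError if the input already has a fixed batch size that
    differs from batch_size. A value of -1 is treated as dynamic.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    shape = list(onnx_shape)
    if shape[0] != -1 and shape[0] != batch_size:
        raise ValueError(
            "ONNX input has static batch {} but batch_size={} was requested"
            .format(shape[0], batch_size))
    shape[0] = batch_size
    return shape


print(override_batch_dim([-1, 3, 416, 416], 4))  # [4, 3, 416, 416]
print(override_batch_dim([4, 3, 416, 416], 4))   # [4, 3, 416, 416]
```

Inside `build_engine`, it could replace the three lines that edit `shape`, e.g. `network.get_input(0).shape = override_batch_dim(network.get_input(0).shape, batch_size)`. A mismatch then fails at build time instead of producing an engine whose batch differs from the one fed at inference.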