PINTO0309 / PINTO_model_zoo

A repository for storing models that have been inter-converted between various frameworks. Supported frameworks are TensorFlow, PyTorch, ONNX, OpenVINO, TFJS, TFTRT, TensorFlowLite (Float32/16/INT8), EdgeTPU, CoreML.
https://qiita.com/PINTO
MIT License

About the PyTorch model conversion to ONNX (324_Ultra-Fast-Lane-Detection-v2) #330

Closed (lwplw closed this issue 1 year ago)

lwplw commented 1 year ago

Issue Type

Others

OS

Ubuntu

OS architecture

x86_64

Programming Language

C++, Python

Framework

PyTorch, ONNX, TensorRT

Model name and Weights/Checkpoints URL

  1. .pth: https://github.com/cfzd/Ultra-Fast-Lane-Detection-v2

  2. .onnx: https://github.com/PINTO0309/PINTO_model_zoo/tree/main/324_Ultra-Fast-Lane-Detection-v2

  3. onnx2trt: https://github.com/iwatake2222/play_with_tensorrt/tree/master/pj_tensorrt_lane_ultra-fast-lane-detection_v2

Description

How do you convert the PyTorch model to ONNX?

I followed https://github.com/jason-li-831202/Vehicle-CV-ADAS/blob/master/TrafficLaneDetector/convertPytorchToONNX.py to convert the model, but the resulting ONNX model differs from yours, and pj_tensorrt_lane_ultra-fast-lane-detection_v2 cannot load it.

This is the ONNX model I converted: culane_res34.zip

Relevant Log Output

[03/14/2023-18:44:35] [I] [TRT] Loaded engine size: 413 MiB
[03/14/2023-18:44:35] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +10, now: CPU 4005, GPU 1768 (MiB)
[03/14/2023-18:44:35] [I] [TRT] [MemUsageChange] TensorRT-managed allocation in engine deserialization: CPU +0, GPU +412, now: CPU 0, GPU 412 (MiB)
[03/14/2023-18:44:35] [I] [TRT] [MemUsageChange] Init cuBLAS/cuBLASLt: CPU +0, GPU +8, now: CPU 2439, GPU 1672 (MiB)
[03/14/2023-18:44:35] [I] [TRT] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +24, now: CPU 0, GPU 436 (MiB)
[InferenceHelperTensorRt][345] num_of_in_out = 5
[InferenceHelperTensorRt][348] tensor[0]->name: input.1
[InferenceHelperTensorRt][349]   is input = 1
[InferenceHelperTensorRt][353]   dims.d[0] = 1
[InferenceHelperTensorRt][353]   dims.d[1] = 3
[InferenceHelperTensorRt][353]   dims.d[2] = 320
[InferenceHelperTensorRt][353]   dims.d[3] = 1600
[InferenceHelperTensorRt][357]   data_type = 0
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: input
[InferenceHelperTensorRt][348] tensor[1]->name: 368
[InferenceHelperTensorRt][349]   is input = 0
[InferenceHelperTensorRt][353]   dims.d[0] = 1
[InferenceHelperTensorRt][353]   dims.d[1] = 200
[InferenceHelperTensorRt][353]   dims.d[2] = 72
[InferenceHelperTensorRt][353]   dims.d[3] = 4
[InferenceHelperTensorRt][357]   data_type = 0
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: loc_row
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: loc_col
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: exist_row
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: exist_col
[InferenceHelperTensorRt][348] tensor[2]->name: 371
[InferenceHelperTensorRt][349]   is input = 0
[InferenceHelperTensorRt][353]   dims.d[0] = 1
[InferenceHelperTensorRt][353]   dims.d[1] = 100
[InferenceHelperTensorRt][353]   dims.d[2] = 81
[InferenceHelperTensorRt][353]   dims.d[3] = 4
[InferenceHelperTensorRt][357]   data_type = 0
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: loc_row
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: loc_col
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: exist_row
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: exist_col
[InferenceHelperTensorRt][348] tensor[3]->name: 374
[InferenceHelperTensorRt][349]   is input = 0
[InferenceHelperTensorRt][353]   dims.d[0] = 1
[InferenceHelperTensorRt][353]   dims.d[1] = 2
[InferenceHelperTensorRt][353]   dims.d[2] = 72
[InferenceHelperTensorRt][353]   dims.d[3] = 4
[InferenceHelperTensorRt][357]   data_type = 0
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: loc_row
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: loc_col
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: exist_row
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: exist_col
[InferenceHelperTensorRt][348] tensor[4]->name: 377
[InferenceHelperTensorRt][349]   is input = 0
[InferenceHelperTensorRt][353]   dims.d[0] = 1
[InferenceHelperTensorRt][353]   dims.d[1] = 2
[InferenceHelperTensorRt][353]   dims.d[2] = 81
[InferenceHelperTensorRt][353]   dims.d[3] = 4
[InferenceHelperTensorRt][357]   data_type = 0
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: loc_row
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: loc_col
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: exist_row
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: exist_col
[ERR: InferenceHelperTensorRt][216] Input tensor doesn't exist in the model (input)
[ERR: LaneEngine][157] Inference helper is not created
Initialization Error
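
The log above can be read mechanically: the engine exposes one set of tensor names while the application asks for another. The following is a minimal sketch, using only the standard library, that extracts both sets from an excerpt of the log. The function name `summarize_bindings` and the `LOG` excerpt are illustrative choices, not part of play_with_tensorrt.

```python
import re

# Excerpt of the log above (shape and data_type lines omitted for brevity).
LOG = """\
[InferenceHelperTensorRt][348] tensor[0]->name: input.1
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: input
[InferenceHelperTensorRt][348] tensor[1]->name: 368
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: loc_row
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: loc_col
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: exist_row
[03/14/2023-18:44:35] [E] [TRT] 3: Cannot find binding of given name: exist_col
[InferenceHelperTensorRt][348] tensor[2]->name: 371
[InferenceHelperTensorRt][348] tensor[3]->name: 374
[InferenceHelperTensorRt][348] tensor[4]->name: 377
"""


def summarize_bindings(log_text):
    """Return (names present in the engine, names the application failed to find)."""
    engine_names = re.findall(r"tensor\[\d+\]->name: (\S+)", log_text)
    missing = []
    for name in re.findall(r"Cannot find binding of given name: (\S+)", log_text):
        if name not in missing:  # the same error repeats for every tensor
            missing.append(name)
    return engine_names, missing


engine_names, missing = summarize_bindings(LOG)
print("engine tensors:  ", engine_names)
print("names not found: ", missing)
```

The two lists have the same length but share no entries, which suggests the model itself loaded and only the tensor names differ from what the application expects.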

URL or source code for simple inference testing code

https://github.com/iwatake2222/play_with_tensorrt/tree/master/pj_tensorrt_lane_ultra-fast-lane-detection_v2

PINTO0309 commented 1 year ago
import os
import torch
from utils.common import merge_config, get_model
from evaluation.eval_wrapper import eval_lane

if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True

    args, cfg = merge_config()

    distributed = False
    if 'WORLD_SIZE' in os.environ:
        distributed = int(os.environ['WORLD_SIZE']) > 1
    cfg.distributed = distributed
    # if distributed:
    #     torch.cuda.set_device(args.local_rank)
    #     torch.distributed.init_process_group(backend='nccl', init_method='env://')

    net = get_model(cfg)

    state_dict = torch.load(cfg.test_model, map_location = 'cpu')['model']
    compatible_state_dict = {}
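    # Strip the 'module.' prefix that DataParallel/DistributedDataParallel
    # adds to parameter names, so the keys match a plain (unwrapped) model.
    for k, v in state_dict.items():
        if k.startswith('module.'):
            compatible_state_dict[k[len('module.'):]] = v
        else:
            compatible_state_dict[k] = v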

    net.load_state_dict(compatible_state_dict, strict = True)

    # if distributed:
    #     net = torch.nn.parallel.DistributedDataParallel(net, device_ids = [args.local_rank])

    # if not os.path.exists(cfg.test_work_dir):
    #     os.mkdir(cfg.test_work_dir)

    eval_lane(net, cfg)

    import onnx
    from onnxsim import simplify
    RESOLUTION = [
        # [320,1600],
        # [384,640],
        # [480,640],
        # [450,800],
        # [720,1280],
        [cfg.train_height, cfg.train_width],
    ]
    BACKBONE=f'{cfg.dataset.lower()}_res{cfg.backbone}'
    MODEL = f'ufldv2_{BACKBONE}'
    for H, W in RESOLUTION:
        onnx_file = f"{MODEL}_{H}x{W}.onnx"
        x = torch.randn(1, 3, H, W).cuda()
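        # Name the graph inputs and outputs explicitly. Without these names the
        # exporter auto-generates ones such as 'input.1' and '368', which the
        # TensorRT sample cannot find.
        torch.onnx.export(
            net,
            args=(x,),
            f=onnx_file,
            opset_version=11,
            input_names=['input'],
            output_names=['loc_row','loc_col','exist_row','exist_col'],
        )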
        model_onnx1 = onnx.load(onnx_file)
        model_onnx1 = onnx.shape_inference.infer_shapes(model_onnx1)
        onnx.save(model_onnx1, onnx_file)

        # Run onnx-simplifier three times; a later pass may fold nodes that an
        # earlier pass exposed.
        for _ in range(3):
            model_onnx2 = onnx.load(onnx_file)
            model_simp, check = simplify(model_onnx2)
            onnx.save(model_simp, onnx_file)

    import sys
    sys.exit(0)