MrHarsh10 / tspi_-RKNN_MobileNetV3

Convert MobileNetV3 to RKNN and deploy it on the Taishan Pi (泰山派)

Model conversion error #1

Open smilepolker opened 5 months ago

smilepolker commented 5 months ago

Problem description: after training, the generated MobileNetV3.pt cannot be converted to tspi_moblienetv3_demo.rknn with pt2rknn.py.
Version: current version
Steps to reproduce:

  1. python train.py
  2. python pt2rknn.py

Log:

    (rknn) chgr@gpu15:~/tspi_RKNN_MobileNetV3$ python pt2rknn.py
    I rknn-toolkit2 version: 2.0.0b0+9bab5682
    --> Config model
    done
    --> Loading model
    W load_pytorch: Catch exception when torch.jit.load:
    RuntimeError('PytorchStreamReader failed locating file constants.pkl: file not found')
    W load_pytorch: Make sure that the torch version of '/home/chgr/tspi_RKNN_MobileNetV3/MobileNetV3.pt' is consistent with the installed torch version '2.1.0+cu121'!
    E load_pytorch: Traceback (most recent call last):
    E load_pytorch:   File "rknn/api/rknn_base.py", line 1590, in rknn.api.rknn_base.RKNNBase.load_pytorch
    E load_pytorch:   File "/home/chgr/.conda/envs/rknn/lib/python3.10/site-packages/torch/jit/_serialization.py", line 162, in load
    E load_pytorch:     cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files, _restore_shapes)  # type: ignore[call-arg]
    E load_pytorch: RuntimeError: PytorchStreamReader failed locating file constants.pkl: file not found
    W If you can't handle this error, please try updating to the latest version of the toolkit2 and runtime from:
    https://console.zbox.filez.com/l/I00fc3 (Pwd: rknn)  Path: RKNPU2_SDK / 2.X.X / develop /
    If the error still exists in the latest version, please collect the corresponding error logs and the model,
    convert script, and input data that can reproduce the problem, and then submit an issue on:
    https://redmine.rock-chips.com (Please consult our sales or FAE for the redmine account)
    Load model failed!
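The key line is `PytorchStreamReader failed locating file constants.pkl`. According to the traceback, `load_pytorch` calls `torch.jit.load`, which expects a TorchScript archive, and a model pickled with `torch.save` does not contain `constants.pkl`. As a quick pre-check before running pt2rknn.py, the sketch below looks inside the .pt file (recent PyTorch versions write it as a zip archive) for that entry. `is_torchscript_archive` is a hypothetical helper name, not part of rknn-toolkit2 or PyTorch, and the check is only a heuristic derived from the error message, not a full validation.

```python
import zipfile


def is_torchscript_archive(path: str) -> bool:
    """Heuristic: report whether a .pt file looks like a TorchScript archive.

    Both torch.save (zip format is the default since PyTorch 1.6) and
    torch.jit.save write zip archives, but only TorchScript archives contain
    a constants.pkl entry, the file torch.jit.load failed to find in the log.
    """
    if not zipfile.is_zipfile(path):
        # Legacy (non-zip) pickles cannot be read by torch.jit.load either.
        return False
    with zipfile.ZipFile(path) as archive:
        # Entries sit under a top-level folder, e.g. "<name>/constants.pkl",
        # so compare only the last path component.
        return any(
            name.rsplit("/", 1)[-1] == "constants.pkl"
            for name in archive.namelist()
        )


if __name__ == "__main__":
    import sys

    # Usage: python check_pt.py MobileNetV3.pt
    for pt_path in sys.argv[1:]:
        kind = "TorchScript" if is_torchscript_archive(pt_path) else "not TorchScript"
        print(f"{pt_path}: {kind}")
```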

MrHarsh10 commented 5 months ago

Judging from the error, you may not have installed the PC wheel of rknn-toolkit2.
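The log already prints `I rknn-toolkit2 version: 2.0.0b0+9bab5682`, and its W line asks that the torch version used to save the model match the installed one. A small standard-library sketch for confirming both in the conversion environment; `installed_versions` is a hypothetical helper, and the pip distribution names `torch` and `rknn-toolkit2` are assumptions about how the packages were installed.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed pip distribution names; adjust if the wheel was installed differently.
    for pkg, ver in installed_versions(["torch", "rknn-toolkit2"]).items():
        print(f"{pkg}: {ver or 'NOT INSTALLED'}")
```

Running it in the environment used for training and in the one used for pt2rknn.py shows whether the torch versions differ.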

smilepolker commented 5 months ago

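Well, I believe I did install it. My environment trains on the CPU, and I'm not sure whether that is related. In any case, the problem is now solved: I changed the code that saves MobileNetV3.pt, and after that the conversion works normally. The modified file is train.py, at line 95: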

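    import torch

    model.eval()  # trace in inference mode so Dropout/BatchNorm behave as at deployment
    input_tensor = torch.rand(1, 3, 224, 224)  # dummy input with the 1x3x224x224 shape
    traced_script_module = torch.jit.trace(model, input_tensor)
    # save as a TorchScript archive, which torch.jit.load (used by load_pytorch) can read
    traced_script_module.save("MobileNetV3.pt")

    # torch.save(model, "MobileNetV3.pt")  # original line: pickles the nn.Module, not TorchScript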

Could you check whether my code has any problems, and then update your repository?
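For reference, a self-contained sketch of the export step this fix relies on: trace the model in eval mode, save it as TorchScript, then reload it with `torch.jit.load` (the call that failed in the log) and compare outputs. `TinyNet` and `export_torchscript` are hypothetical names; `TinyNet` is a small stand-in for the repository's MobileNetV3 so the example runs without the training code. The 1x3x224x224 input shape comes from the snippet above, and the file name `tiny_traced.pt` is illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for MobileNetV3: conv + BN + pooling + dropout head."""

    def __init__(self, num_classes: int = 5) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(8),
            nn.Hardswish(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.classifier = nn.Sequential(nn.Dropout(0.2), nn.Linear(8, num_classes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


def export_torchscript(model: nn.Module, path: str) -> torch.Tensor:
    """Trace `model` in eval mode, save it as TorchScript and verify the reload."""
    model.eval()  # fix Dropout/BatchNorm to inference behaviour before tracing
    example = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
        traced.save(path)  # TorchScript archive (contains constants.pkl)
        reloaded = torch.jit.load(path)  # same loader rknn-toolkit2 used in the log
        expected = model(example)
        actual = reloaded(example)
    if not torch.allclose(expected, actual, atol=1e-5):
        raise RuntimeError("reloaded TorchScript output differs from the eager model")
    return actual


if __name__ == "__main__":
    out_path = os.path.join(tempfile.mkdtemp(), "tiny_traced.pt")
    logits = export_torchscript(TinyNet(), out_path)
    print(out_path, tuple(logits.shape))
```

Reloading with `torch.jit.load` before running pt2rknn.py surfaces the same kind of load error locally, without going through the RKNN toolkit.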