hailanyi / CasA

A Cascade Attention Network for 3D Object Detection from LiDAR point clouds
https://ieeexplore.ieee.org/abstract/document/9870747
Apache License 2.0
125 stars 25 forks

Problems when exporting with torch.onnx.export() #9

Closed ghost closed 1 year ago

ghost commented 1 year ago

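Hello! I would like to convert the released CasA-V model to ONNX format for some testing. I downloaded the CasA-V.pth file and placed it in the tools folder. In tools/test.py I deleted the original lines 190-195 and added the following code after line 189: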

model.load_params_from_file(filename=args.ckpt, logger=logger, to_cpu=dist_test)
# set the model to inference mode
model.eval()
# Let's create a dummy input tensor
dummy_input = torch.randn(1, 3, 244, 244, requires_grad=True)

# Export the model
torch.onnx.export(model=model,  # model being run
                  args=dummy_input,  # model input (or a tuple for multiple inputs)
                  f="Output-CasA.onnx",  # where to save the model
                  export_params=True,  # store the trained parameter weights inside the model file
                  opset_version=10,  # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['modelInput'],  # the model's input names
                  output_names=['modelOutput'],  # the model's output names
                  dynamic_axes={'modelInput': {0: 'batch_size'},  # variable length axes
                                'modelOutput': {0: 'batch_size'}}
                  )

Error output from PyCharm:

Traceback (most recent call last):
  File "/home/jlf/PycharmProjects/CasA/tools/pthToOnnx.py", line 148, in <module>
    main()
  File "/home/jlf/PycharmProjects/CasA/tools/pthToOnnx.py", line 123, in main
    torch.onnx.export(model=model,  # model being run
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/onnx/__init__.py", line 203, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/onnx/utils.py", line 86, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/onnx/utils.py", line 526, in _export
    graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/onnx/utils.py", line 366, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/onnx/utils.py", line 319, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/jit/__init__.py", line 338, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/jit/__init__.py", line 421, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/jit/__init__.py", line 412, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 720, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 704, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/jlf/PycharmProjects/CasA/pcdet/models/detectors/voxel_rcnn.py", line 11, in forward
    batch_dict = cur_module(batch_dict)
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 720, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 704, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/jlf/PycharmProjects/CasA/pcdet/models/backbones_3d/vfe/mean_vfe.py", line 27, in forward
    if 'semi_test' in batch_dict:
  File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/tensor.py", line 502, in __contains__
    raise RuntimeError(
RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <class 'str'>.

Process finished with exit code 1

hailanyi commented 1 year ago

Sorry, I have no experience with this, so I cannot solve your problem.
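
For readers who hit the same error: the last frames of the traceback show that forward in pcdet/models/detectors/voxel_rcnn.py hands its argument on as batch_dict, and mean_vfe.py line 27 evaluates 'semi_test' in batch_dict. That membership test works on a dict, but here the argument is the torch.randn(1, 3, 244, 244) tensor, which leads to the RuntimeError reported above. The sketch below is a minimal, pure-Python illustration of checking the input type before calling the exporter. The helper name require_batch_dict and the key example_key are hypothetical and not part of CasA or PyTorch; the keys a real batch needs are not stated in this thread.

```python
from collections.abc import Mapping


def require_batch_dict(model_input):
    """Return model_input unchanged if it is dict-like, else raise TypeError.

    Hypothetical helper (not part of CasA or PyTorch) that fails early with a
    readable message instead of the RuntimeError raised inside mean_vfe.py.
    """
    if not isinstance(model_input, Mapping):
        raise TypeError(
            "expected a dict-like batch_dict, got "
            f"{type(model_input).__name__}"
        )
    return model_input


# A dict-like batch supports the membership test that mean_vfe.py performs.
# 'example_key' is a placeholder, not a real CasA batch key.
batch = require_batch_dict({"example_key": 0})
has_semi_test = "semi_test" in batch  # False for this placeholder batch

# A non-dict input (a nested list stands in for the dummy image tensor here)
# is rejected before any tracing would start.
try:
    require_batch_dict([[0.0, 0.0], [0.0, 0.0]])
    rejected = False
except TypeError as err:
    rejected = True
    message = str(err)
```

This only guards against the type mismatch seen above. Whether the rest of the network can be traced by torch.onnx.export is a separate question that this thread does not answer.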