cxliu0 / PET

[ICCV 2023] Point-Query Quadtree for Crowd Counting, Localization, and More
MIT License

ONNX conversion problem #32

Open yhgyhg123 opened 3 days ago

yhgyhg123 commented 3 days ago

I tried to convert the .pth checkpoint to an ONNX file, and the script raised the following error:

    Traceback (most recent call last):
      File "/home/server/YHG/PET_main/pth_onnx.py", line 100, in <module>
        convert_pth_to_onnx(args)
      File "/home/server/YHG/PET_main/pth_onnx.py", line 77, in convert_pth_to_onnx
        output = model(input_data)
      File "/home/server/anaconda3/envs/yhgyhg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/server/YHG/PET_main/models/pet.py", line 304, in forward
        out = self.test_forward(samples, features, pos, **kwargs)
      File "/home/server/YHG/PET_main/models/pet.py", line 357, in test_forward
        outputs = self.pet_forward(samples, features, pos, **kwargs)
      File "/home/server/YHG/PET_main/models/pet.py", line 327, in pet_forward
        outputs_sparse = self.quadtree_sparse(samples, features, context_info, **kwargs)
      File "/home/server/anaconda3/envs/yhgyhg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/server/YHG/PET_main/models/pet.py", line 166, in forward
        hs = self.transformer(encode_src, src_pos_embed, mask, pqs, img_shape=samples.tensors.shape[-2:], **kwargs)
      File "/home/server/anaconda3/envs/yhgyhg/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/server/YHG/PET_main/models/transformer/prog_win_transformer.py", line 141, in forward
        hs = self.decoder_forward(query_feats, query_embed,
      File "/home/server/YHG/PET_main/models/transformer/prog_win_transformer.py", line 96, in decoder_forward
        query_embed = query_embed.permute(1,2,0).reshape(bs, c, qH, qW)
    RuntimeError: shape '[1, 256, 50, 256]' is invalid for input of size 1638400
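The numbers in the error can be checked by hand: `reshape` only succeeds when the product of the target shape equals the tensor's element count. A minimal stdlib sketch, assuming the 256 in the target shape is the channel count `c`; the helper name `reshape_mismatch` is hypothetical and not part of PET or PyTorch:

```python
from math import prod


def reshape_mismatch(numel, target_shape):
    """Return (requested_numel, ratio) for a reshape attempt.

    numel        -- number of elements the tensor actually holds
    target_shape -- shape passed to reshape(), e.g. (bs, c, qH, qW)
    """
    requested = prod(target_shape)
    return requested, requested / numel


# Values copied from the RuntimeError in the traceback above.
numel = 1638400
bs, c, qH, qW = 1, 256, 50, 256

requested, ratio = reshape_mismatch(numel, (bs, c, qH, qW))
print(requested, ratio)     # 3276800 2.0
print(numel // c, qH * qW)  # 6400 positions held vs 12800 requested
```

Under that assumption, `query_embed` carries 256 channels over 6,400 positions, while `qH * qW` asks for 12,800, so the grid size computed from `img_shape` is exactly twice what the query tensor holds.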

cxliu0 commented 1 day ago

What is the resolution of the input image? We suggest making sure the resolution is divisible by 256; for example, 1024x768 is acceptable.
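The suggestion above can be checked before building a dummy input. A minimal sketch, assuming the size is given as an `(H, W)` pair and using the divisor 256 suggested here; `round_up_to_multiple` and `check_resolution` are hypothetical helpers, not part of PET:

```python
def round_up_to_multiple(value, multiple=256):
    """Smallest multiple of `multiple` that is >= value (ceiling division)."""
    return -(-value // multiple) * multiple


def check_resolution(h, w, multiple=256):
    """Return (ok, suggested_size) for an (H, W) input resolution."""
    ok = h % multiple == 0 and w % multiple == 0
    suggested = (round_up_to_multiple(h, multiple),
                 round_up_to_multiple(w, multiple))
    return ok, suggested


print(check_resolution(768, 1024))  # (True, (768, 1024))
print(check_resolution(720, 1280))  # (False, (768, 1280))
```

Rounding up implies padding the image (for example on the bottom and right) to reach the suggested size rather than resizing it.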

yhgyhg123 commented 1 day ago

import torch

# Define the input tensor
input_data = torch.randn(args.batch_size, 3, *args.img_size).to(device)
output = model(input_data)

# ONNX export
model_name = args.resume.replace('.pth', '.onnx')
input_name = 'input'
pred_logits = 'pred_logits'
pred_points = 'pred_points'
torch.onnx.export(model,
                  input_data,
                  model_name,
                  opset_version=11,
                  input_names=[input_name],
                  output_names=[pred_logits, pred_points],
                  dynamic_axes={
                      input_name: {0: 'batch_size'},
                      pred_logits: {0: 'batch_size'},
                      pred_points: {0: 'batch_size'}})

I tried setting the image size to 768x768, but I still get the same error. I'm not sure what is causing it; could you please take a look?

    RuntimeError: shape '[1, 256, 72, 256]' is invalid for input of size 2359296
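The 768x768 error can be checked the same way with plain arithmetic. A minimal sketch, assuming the 256 in the target shape is the channel count `c`: it shows that 768 is divisible by 256, yet the target shape still requests exactly twice the elements the tensor holds. The final line notes that 9216 equals (768 / 8)²; reading that as a stride-8 query grid in PET is only an assumption to verify against `models/pet.py`, not something this thread confirms.

```python
from math import prod

H, W = 768, 768             # resolution tried in the comment above
numel = 2359296             # "input of size ..." from the RuntimeError
target = (1, 256, 72, 256)  # shape passed to reshape(bs, c, qH, qW)

requested = prod(target)
positions = numel // target[1]  # query positions present, assuming c = 256

print(H % 256 == 0, W % 256 == 0)        # True True
print(requested, requested / numel)      # 4718592 2.0
print(positions, target[2] * target[3])  # 9216 18432
# Coincidence check only (assumption): 9216 == 96 * 96 == (768 // 8) ** 2
print(positions == (H // 8) * (W // 8))  # True
```

If the divisibility condition holds and the mismatch remains, the grid size `qH x qW` derived from `img_shape`, rather than the input resolution itself, is the value worth inspecting.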