NVIDIA-AI-IOT / torch2trt

An easy to use PyTorch to TensorRT converter
MIT License
4.58k stars, 675 forks

Can nn.PixelShuffle() be converted successfully? #389

Closed: PressEtoRace closed this issue 4 years ago

PressEtoRace commented 4 years ago

When I convert my model to TensorRT on an AGX Xavier, I get the error log below. I think this is because torch2trt does not support nn.PixelShuffle(). Has anyone successfully converted nn.PixelShuffle(), or can someone tell me how to do the conversion?

    Warning: Encountered known unsupported method torch.nn.functional.pixel_shuffle
    [TensorRT] ERROR: (Unnamed Layer* 2) [Scale]: shift weights has count 1024 but 1 was expected
    Traceback (most recent call last):
      File "test.py", line 97, in <module>
        test_model = Pose_post()
      File "/home/dig/workspace/Real_Action_Recognition_611ubuntu/action_recognition/PoseEstimation/AlphaPose/SPPE/src/main_fast_inference.py", line 115, in __init__
        model_trt = torch2trt(model, [x])
      File "/usr/local/lib/python3.6/dist-packages/torch2trt-0.1.0-py3.6-linux-aarch64.egg/torch2trt/torch2trt.py", line 436, in torch2trt
        outputs = module(*inputs)
      File "/home/dig/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/dig/workspace/Real_Action_Recognition_611ubuntu/action_recognition/PoseEstimation/AlphaPose/SPPE/src/models/FastPose.py", line 61, in forward
        out = self.duc1(out)
      File "/home/dig/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/dig/workspace/Real_Action_Recognition_611ubuntu/action_recognition/PoseEstimation/AlphaPose/SPPE/src/models/layers/DUC.py", line 21, in forward
        x = self.relu(x)
      File "/home/dig/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/usr/local/lib/python3.6/dist-packages/torch2trt-0.1.0-py3.6-linux-aarch64.egg/torch2trt/torch2trt.py", line 218, in wrapper
        converter['converter'](ctx)
      File "/usr/local/lib/python3.6/dist-packages/torch2trt-0.1.0-py3.6-linux-aarch64.egg/torch2trt/converters/ReLU.py", line 7, in convert_ReLU
        input_trt = trt_(ctx.network, input)
      File "/usr/local/lib/python3.6/dist-packages/torch2trt-0.1.0-py3.6-linux-aarch64.egg/torch2trt/torch2trt.py", line 132, in trt_
        t._trt.shape
    ValueError: len() should return >= 0
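A possible workaround, sketched here under assumptions and not confirmed anywhere in this thread: pixel shuffle is only a reshape, a permute, and a second reshape, so the model can express it with `view`/`permute`/`reshape`, which torch2trt ships converters for, instead of calling `torch.nn.functional.pixel_shuffle`. Whether a given torch2trt/TensorRT version converts this 6-D intermediate shape end to end is an assumption to verify on the Xavier. The class name `PixelShuffleViaReshape` is hypothetical. The module implements the mapping `out[n, c, h*r + i, w*r + j] = in[n, c*r*r + i*r + j, h, w]`.

```python
import torch
import torch.nn as nn


class PixelShuffleViaReshape(nn.Module):
    """Hypothetical drop-in replacement for nn.PixelShuffle.

    Uses only view/permute/reshape (assumption: these convert in torch2trt),
    so no call to torch.nn.functional.pixel_shuffle appears in the trace.
    """

    def __init__(self, upscale_factor: int):
        super().__init__()
        self.r = upscale_factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        r = self.r
        assert c % (r * r) == 0, "channels must be divisible by upscale_factor**2"
        oc = c // (r * r)
        # (N, C*r*r, H, W) -> (N, C, r, r, H, W): split channels into sub-pixel offsets i, j
        x = x.view(n, oc, r, r, h, w)
        # -> (N, C, H, r, W, r): move offset i next to H and offset j next to W
        # (batch axis stays first, which implicit-batch converters usually require)
        x = x.permute(0, 1, 4, 2, 5, 3)
        # -> (N, C, H*r, W*r): merge each spatial axis with its offset
        return x.reshape(n, oc, h * r, w * r)
```

Usage: replace each `nn.PixelShuffle(r)` in the model with `PixelShuffleViaReshape(r)` before calling `torch2trt(model, [x])`. Neither module has parameters, so existing checkpoints load unchanged.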

464hee commented 1 year ago

Hello, have you managed to implement a converter for this method yet?