I hope everything is OK. I am trying to use the pretrained PSMNet model to make a prediction, but I got the following error:
```
Traceback (most recent call last):
  File "predict.py", line 199, in <module>
    main()
  File "predict.py", line 162, in main
    pred_disp = aanet(left, right)[-1] # [B, H, W]
  File "/usr/local/envs/aanet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/envs/aanet/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/usr/local/envs/aanet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/content/drive/My Drive/AANet/nets/aanet.py", line 207, in forward
    left_feature = self.feature_extraction(left_img)
  File "/content/drive/My Drive/AANet/nets/aanet.py", line 136, in feature_extraction
    feature = self.fpn(feature)
  File "/usr/local/envs/aanet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/content/drive/My Drive/AANet/nets/feature.py", line 213, in forward
    assert len(self.in_channels) == len(inputs)
AssertionError
```
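For context, the failing line in `nets/feature.py` compares the number of channel entries the FPN was built with (`self.in_channels`) against the number of feature maps it receives (`inputs`). One possible cause, which is only an assumption and is not confirmed here, is that the feature extractor used at prediction time returns a different number of pyramid levels than the FPN was configured for. The sketch below is a hypothetical, pure-Python stand-in (`ToyFPN`, `feature_levels_match`, and the channel list are illustrative names, not AANet's actual code) that reproduces only that check:

```python
# Hypothetical, simplified stand-in for an FPN module; NOT AANet's actual code.
# It only mimics the sanity check that fails in the traceback above.
from typing import List, Sequence


class ToyFPN:
    def __init__(self, in_channels: Sequence[int]):
        # One entry per pyramid level this FPN expects to receive.
        self.in_channels = list(in_channels)

    def forward(self, inputs: Sequence[object]) -> List[object]:
        # Same kind of check as nets/feature.py line 213: the number of
        # feature maps must equal the number of configured levels.
        assert len(self.in_channels) == len(inputs)
        return list(inputs)


def feature_levels_match(fpn: ToyFPN, features: Sequence[object]) -> bool:
    """Return True if fpn accepts the features, False on a level-count mismatch."""
    try:
        fpn.forward(features)
        return True
    except AssertionError:
        return False


if __name__ == "__main__":
    fpn = ToyFPN(in_channels=[128, 256, 512])  # hypothetical: expects 3 levels
    print(feature_levels_match(fpn, ["f1", "f2", "f3"]))  # True
    print(feature_levels_match(fpn, ["f1"]))  # False: fewer levels than expected
```

If this is the cause, the fix would be on the configuration side (making the model options used for prediction match the ones the checkpoint was trained with), not in the assertion itself.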
Could you give me some advice? Any advice is appreciated!