zju3dv / pvnet

Code for "PVNet: Pixel-wise Voting Network for 6DoF Pose Estimation" (CVPR 2019 oral)
Apache License 2.0
814 stars, 145 forks

replace torch.gesv with torch.linalg.solve #177

monajalal closed this issue 10 months ago

monajalal commented 1 year ago

I am getting the error below because `torch.gesv` was deprecated and has since been removed from PyTorch. Is it valid to replace it with `torch.linalg.solve`?

```text
(pvnet) mona@mona-ThinkStation-P7:~/pvnet$ python tools/demo.py
/home/mona/anaconda3/envs/pvnet/lib/python3.10/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
  warnings.warn(warning.format(ret))
No saved models found.
/home/mona/pvnet/lib/ransac_voting_gpu_layer/ransac_voting_gpu.py:544: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at ../aten/src/ATen/native/IndexingUtils.h:27.)
  direct = vertex[bi].masked_select(torch.unsqueeze(torch.unsqueeze(cur_mask, 2), 3))  # [tn,vn,2]
Traceback (most recent call last):
  File "/home/mona/pvnet/tools/demo.py", line 189, in <module>
    demo()
  File "/home/mona/pvnet/tools/demo.py", line 175, in demo
    corner_pred = eval_net(seg_pred, vertex_pred).cpu().detach().numpy()[0]
  File "/home/mona/anaconda3/envs/pvnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mona/anaconda3/envs/pvnet/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 169, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/mona/anaconda3/envs/pvnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mona/pvnet/tools/demo.py", line 55, in forward
    return ransac_voting_layer_v3(mask, vertex_pred, 512, inlier_thresh=0.99)
  File "/home/mona/pvnet/lib/ransac_voting_gpu_layer/ransac_voting_gpu.py", line 594, in ransac_voting_layer_v3
    all_win_pts=torch.matmul(b_inv(ATA),torch.unsqueeze(ATb,2)) # [vn,2,1]
  File "/home/mona/pvnet/lib/ransac_voting_gpu_layer/ransac_voting_gpu.py", line 511, in b_inv
    b_inv, _ = torch.gesv(eye, b_mat)
AttributeError: module 'torch' has no attribute 'gesv'
```
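The change I made in `b_inv`:

```python
# b_inv, _ = torch.gesv(eye, b_mat)
# torch.gesv(B, A) solved A X = B; torch.linalg.solve(A, B) takes A first,
# so the arguments are swapped to keep computing the inverse of b_mat.
b_inv = torch.linalg.solve(b_mat, eye)
```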
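A minimal sketch of a drop-in `b_inv`, assuming the helper receives a batch of square matrices shaped `[..., n, n]` and builds its own identity matrix; only the `gesv` line of the original helper appears in this issue, so the `eye` construction below is an assumption. `torch.gesv(B, A)` returned the solution `X` of `A X = B` together with its LU factorization, while `torch.linalg.solve(A, B)` takes the coefficient matrix first and returns only `X`. So `torch.gesv(eye, b_mat)` corresponds to `torch.linalg.solve(b_mat, eye)`; keeping the arguments in their old positions solves `eye X = b_mat` and silently returns `b_mat` unchanged instead of its inverse.

```python
import torch


def b_inv(b_mat: torch.Tensor) -> torch.Tensor:
    """Invert a batch of square matrices shaped [..., n, n].

    Sketch of a replacement for the torch.gesv-based helper (eye construction
    is an assumption). torch.gesv(B, A) solved A X = B, whereas
    torch.linalg.solve(A, B) takes the coefficient matrix first.
    """
    n = b_mat.size(-1)
    # Identity with the same dtype/device, broadcast to the batch shape.
    eye = torch.eye(n, dtype=b_mat.dtype, device=b_mat.device).expand_as(b_mat)
    # Solves b_mat @ X = eye, i.e. X = inverse of b_mat; returns X only (no LU tuple).
    return torch.linalg.solve(b_mat, eye)


if __name__ == "__main__":
    torch.manual_seed(0)
    # Batch of 4 symmetric positive-definite 2x2 matrices (A^T A plus a small
    # ridge), standing in for the ATA systems of the voting layer (assumed shape).
    a = torch.randn(4, 3, 2, dtype=torch.float64)
    eye2 = torch.eye(2, dtype=torch.float64)
    ata = a.transpose(1, 2) @ a + 1e-3 * eye2
    inv = b_inv(ata)
    print(torch.allclose(ata @ inv, eye2.expand_as(ata), atol=1e-9))  # True
    # Old argument order: solves I X = ata, so the "inverse" is just ata again.
    wrong = torch.linalg.solve(eye2.expand_as(ata), ata)
    print(torch.allclose(wrong, ata))  # True
```

`torch.linalg.inv(b_mat)` computes the same batched inverse directly; the solve form is kept here only because it mirrors the original call.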

I am able to get results from running the demo.
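The log also shows two deprecation warnings that do not stop the demo. A minimal sketch of the usual fixes, with made-up shapes and data: the mask passed to `masked_select` is converted to `torch.bool` (the names `vertex_b` and `cur_mask` and their shapes are hypothetical, inferred from the `[tn,vn,2]` comment in the log), and a loss built with the legacy `reduce=False` argument is rewritten with `reduction='none'`. `SmoothL1Loss` is only an example; the log does not show which loss in PVNet triggers the warning.

```python
import torch
import torch.nn as nn

# 1) Boolean masks for masked_select.
#    Hypothetical stand-ins: vertex_b plays the role of vertex[bi] with an
#    assumed shape [h, w, vn, 2]; cur_mask is an assumed [h, w] uint8 mask.
torch.manual_seed(0)
h, w, vn = 4, 5, 3
vertex_b = torch.randn(h, w, vn, 2)
cur_mask = (torch.rand(h, w) > 0.5).to(torch.uint8)

# Converting the mask to torch.bool avoids the uint8 deprecation warning.
mask_4d = cur_mask.bool().unsqueeze(2).unsqueeze(3)       # [h, w, 1, 1]
# masked_select broadcasts the mask and returns a flat tensor; each selected
# pixel contributes vn * 2 contiguous values, so it reshapes to [tn, vn, 2].
direct = vertex_b.masked_select(mask_4d).view(-1, vn, 2)  # [tn, vn, 2]
print(direct.shape)

# 2) reduction= instead of the legacy size_average/reduce arguments.
#    Legacy form that emits the warning: nn.SmoothL1Loss(reduce=False)
pred = torch.randn(2, 3)
target = torch.randn(2, 3)
loss_fn = nn.SmoothL1Loss(reduction="none")  # element-wise, no averaging
per_elem = loss_fn(pred, target)             # same shape as pred
print(per_elem.shape)
```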