JimmySuen / integral-human-pose

Integral Human Pose Regression

error in inference #16

Closed Keysmis closed 5 years ago

Keysmis commented 5 years ago

Hi, I'm interested in your project. I modified your test code to run inference on a single image and plot the 3D result, but the output is wrong, probably due to a mistake in my code. Could you take a little time to have a look? Much appreciated!

```python
import copy

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

config = copy.deepcopy(s_config)
config.network = get_default_network_config()  # defined in blocks
config.loss = get_default_loss_config()

config = update_config_from_file(config, s_config_file, check_necessity=True)
config = update_config_from_args(config, s_args)  # config in argument is superior to config in file
net = get_pose_net(config.network, 18).cuda()
ckpt = torch.load(s_args.model)  # or other path/to/model

# Strip the leading module prefix from the checkpoint keys.
ckpt_old_key = list(ckpt['network'].keys())
ckpt_new_key = [key.split('.', 1)[-1] for key in ckpt_old_key]
print('ckpt_old_key:', ckpt_old_key)
print('ckpt_new_key:', ckpt_new_key)
ckpt_new = {new: ckpt['network'][old] for old, new in zip(ckpt_old_key, ckpt_new_key)}
net.load_state_dict(ckpt_new)
print('in valid')
net.eval()

# Load and show the input image.
batch_data = cv2.imread('/workspace/w1/wpc/3d_pose_estimatison/integral-human-pose/data/123.png')
batch_data = cv2.cvtColor(batch_data, cv2.COLOR_BGR2RGB)
batch_data = cv2.resize(batch_data, (288, 384))
plt.figure(1)
plt.imshow(batch_data)
plt.show()

# HWC image -> NCHW float tensor.
batch_data = np.array(batch_data).transpose([2, 0, 1])
batch_data = torch.from_numpy(np.reshape(batch_data, [1, 3, batch_data.shape[1], batch_data.shape[2]])).float()
preds_in_patch_with_score = []
batch_data = batch_data.cuda()
preds = net(batch_data)

# Flip test: run the horizontally flipped image and average both predictions.
batch_data_flip = flip(batch_data, dims=3)
preds_flip = net(batch_data_flip)
patch_width = config.train.patch_width
patch_height = config.train.patch_height
flip_pair = np.array([[1, 4], [2, 5], [3, 6], [14, 11], [15, 12], [16, 13]], dtype=int)
pipws = get_joint_location_result(config.loss, patch_width, patch_height, preds)
pipws_flip = get_joint_location_result(config.loss, patch_width, patch_height, preds_flip)
pipws_flip[:, :, 0] = patch_width - pipws_flip[:, :, 0] - 1
for pair in flip_pair:
    tmp = pipws_flip[:, pair[0], :].copy()
    pipws_flip[:, pair[0], :] = pipws_flip[:, pair[1], :].copy()
    pipws_flip[:, pair[1], :] = tmp.copy()
preds_in_patch_with_score.append((pipws + pipws_flip) * 0.5)
_p = np.asarray(preds_in_patch_with_score)
_p = _p.reshape((_p.shape[0] * _p.shape[1], _p.shape[2], _p.shape[3]))

# Rescale from patch space to camera space and plot the 3D skeleton.
target_bone_length = 4502.881  # train+val
preds_in_camera_space = []
parent_ids = np.array([0, 0, 1, 2, 0, 4, 5, 0, 17, 17, 8, 17, 11, 12, 17, 14, 15, 0], dtype=int)
preds_in_camera_space.append(
    rescale_pose_from_patch_to_camera(_p[0],
                                      target_bone_length,
                                      parent_ids))
preds_in_camera_space = np.asarray(preds_in_camera_space)[:, :, 0:3]
fig = plt.figure(2)
ax = fig.add_subplot(111, projection='3d')
plot_3d_skeleton(ax, preds_in_camera_space[0], parent_ids, flip_pair,
                 title='123', patch_width=patch_width, patch_height=patch_height)
plt.show()
```
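(Side note on the snippet above: at inference time the two forward passes could be wrapped in `torch.no_grad()` so PyTorch does not track gradients — a minimal sketch, reusing the names from the code above:)

```python
import torch

# Sketch only: same net / batch_data / flip as in the snippet above.
net.eval()
with torch.no_grad():  # no gradient tracking needed for inference
    preds = net(batch_data)
    preds_flip = net(flip(batch_data, dims=3))
```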

lck1201 commented 5 years ago

@Keysmis Hi, we have provided tools for visualization in common/utility/visualization.py. Besides, your code structure is too complex. I recommend generating the predictions first and saving them, then loading them in a second stage to plot, as sketched below.
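A minimal sketch of that two-stage split (assuming `plot_3d_skeleton` is the helper in common/utility/visualization.py that the comment refers to; the file name and constants mirror the snippet above and are otherwise hypothetical):

```python
import numpy as np
import matplotlib.pyplot as plt

from common.utility.visualization import plot_3d_skeleton  # repo helper

# Stage 1: after inference, dump the (1, num_joints, 3) predictions to disk.
np.save('preds_in_camera_space.npy', preds_in_camera_space)

# Stage 2: in a separate, simpler script, load and plot.
preds = np.load('preds_in_camera_space.npy')
parent_ids = np.array([0, 0, 1, 2, 0, 4, 5, 0, 17, 17, 8, 17, 11, 12, 17, 14, 15, 0], dtype=int)
flip_pair = np.array([[1, 4], [2, 5], [3, 6], [14, 11], [15, 12], [16, 13]], dtype=int)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
plot_3d_skeleton(ax, preds[0], parent_ids, flip_pair,
                 title='prediction', patch_width=288, patch_height=384)
plt.show()
```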

lck1201 commented 5 years ago

Closing this issue; re-open if you have further questions.