zhangy76 / KNOWN

Code repository for Body Knowledge and Uncertainty Modeling for Monocular 3D Human Body Reconstruction
MIT License

Release training code? #2

Closed · yxt7979 closed this 10 months ago

yxt7979 commented 1 year ago

Hi, wonderful work! It is a really good idea to predict the uncertainty of the parameters (the covariance matrix) and use an NLL loss to supervise it. I hope you can release the uncertainty and NLL loss code!

Looking forward to your reply!

zhangy76 commented 1 year ago

Hi,

Thank you for your interest! I've been quite busy lately and may not be able to provide detailed, organized code at the moment, but I cleaned up some rough code below; I hope it's helpful.

For computing the NLL loss,

import torch

def keypoint_loss_NLL(pred_keypoints_2d, gt_keypoints_2d, sigma_kp, Jidx, has_pose_2d, device, sigma_kp_weight=1):
    """ Compute NLL given 2D keypoint labels.
    Args:
        pred_keypoints_2d: [N, J, 2] predicted 2D keypoints (samples from the 2D keypoint prediction distribution)
        gt_keypoints_2d: [N, 23, 3] ground truth 2D keypoints (x, y, confidence)
        sigma_kp: [N, 23, 2] (or flattened) standard deviation of the 2D keypoint prediction
        Jidx: [N, 23] per-sample indices selecting the labeled joints from the J predicted joints
        has_pose_2d: [N] binary flags marking samples with valid 2D labels
        sigma_kp_weight: scalar, weight on regularization term
    Return
        kp_loss_recon: scalar, MSE loss
        kp_loss_NLL: scalar, NLL loss
        kp_sigma: scalar, regularization on the covariance matrix
    """
    # Keep only the samples that actually have 2D keypoint annotations.
    gt_keypoints_2d   = gt_keypoints_2d[has_pose_2d == 1].clone()
    num_valid = len(gt_keypoints_2d)

    if num_valid > 0:
        pred_keypoints_2d = pred_keypoints_2d[has_pose_2d == 1].clone()
        Jidx              = Jidx[has_pose_2d == 1].clone()
        sigma_kp = sigma_kp[has_pose_2d == 1].clone().view(-1, 23, 2)

        # Gather the predicted joints (and their sigmas) that correspond to the labels.
        pred_keypoints_2d = torch.stack([pred_keypoints_2d[i, Jidx[i, :], :] for i in range(num_valid)], dim=0)
        sigma_kp = torch.stack([sigma_kp[i, Jidx[i, :], :] for i in range(num_valid)], dim=0)
        conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()

        # Gaussian NLL: (x - mu)^2 / (2 * sigma^2) + log(sigma), weighted by label confidence.
        kp_loss_mse = (pred_keypoints_2d - gt_keypoints_2d[:, :, :-1])**2
        kp_loss_recon = (kp_loss_mse * conf)
        kp_loss_NLL = ((kp_loss_mse / 2 / sigma_kp**2 + sigma_kp_weight * torch.log(sigma_kp)) * conf)
        kp_sigma = (torch.log(sigma_kp) * conf).mean()
        return kp_loss_recon.sum(2).mean(), kp_loss_NLL.sum(2).mean(), kp_sigma
    else:
        # No valid 2D labels in this batch: return zero losses.
        zeros = torch.zeros(1, device=device)
        return zeros, zeros.clone(), zeros.clone()
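
As a minimal smoke test, the function can be called with random tensors; all shapes here (e.g. J=49 predicted joints, 23 labeled joints) are assumptions based on the docstring, not values confirmed by the repository.

N, J = 4, 49
pred_kp  = torch.rand(N, J, 2)                       # predicted 2D keypoints
gt_kp    = torch.cat([torch.rand(N, 23, 2),
                      torch.ones(N, 23, 1)], dim=2)  # (x, y, confidence)
sigma_kp = torch.rand(N, 23, 2) + 0.1                # predicted std, kept positive
Jidx     = torch.randint(0, J, (N, 23))              # indices of the labeled joints
has_pose_2d = torch.ones(N)

loss_recon, loss_nll, reg_sigma = keypoint_loss_NLL(
    pred_kp, gt_kp, sigma_kp, Jidx, has_pose_2d, device='cpu')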

For quantifying the aleatoric uncertainty, it can be calculated as the trace of the covariance matrix (pred_kp_sigma).
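
A minimal sketch of that computation, assuming pred_kp_sigma holds per-coordinate standard deviations (as in the loss above), so the diagonal covariance entries are its squares:

# Aleatoric uncertainty as the trace of the (diagonal) covariance matrix.
# Assumes pred_kp_sigma is [N, J*2] per-coordinate standard deviations.
Ua = (pred_kp_sigma**2).sum(dim=1)   # [N], one scalar per image

For computing the epistemic uncertainty of the 2D keypoint prediction,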

def Ue_kps(pred_pose_mean, pred_beta_mean, pred_cam_mean, pred_pose_sigma, pred_beta_sigma, pred_kp_sigma, device, num_samples=100):
    """ Compute the epistemic uncertainty of the 2D keypoint prediction by Monte Carlo sampling.
    Args:
        pred_pose_mean: [N, 144] mean of the 6D pose prediction (24 joints x 6)
        pred_beta_mean: [N, 10] mean of the shape prediction
        pred_cam_mean: [N, 3] weak-perspective camera prediction (scale, tx, ty)
        pred_pose_sigma: [N, 138] std of the body pose (global orientation excluded)
        pred_beta_sigma: [N, 10] std of the shape prediction
        pred_kp_sigma: [N, J*2] std of the 2D keypoint prediction (unused in this rough snippet)
    Return
        model_kp_sigma: [N, J*2] epistemic uncertainty (std across samples)
    """
    # Note: rot6d_to_rotmat, smpl, perspective_projection, and focal_length are
    # utilities/constants from the surrounding codebase; numpy is imported as np.
    curr_batch_size = pred_pose_mean.shape[0]
    model_kp_sigma = np.zeros([curr_batch_size, 46])   # 23 joints x 2 coordinates
    for b in range(curr_batch_size):
        # Draw standard normal noise for pose (138), shape (10), and keypoints (46).
        epsilon = torch.normal(0, 1, size=(num_samples, 138+10+46), device=device)
        # Sample body poses around the predicted mean; the global orientation
        # (first 6 values) is kept fixed.
        pred_pose_sample = pred_pose_mean[b:b+1].expand(num_samples, -1).clone()
        pred_pose_sample[:, 6:] = pred_pose_sample[:, 6:] + epsilon[:, :138]*pred_pose_sigma[b:b+1]
        pred_pose_sample_rotmat = rot6d_to_rotmat(pred_pose_sample).view(num_samples, 24, 3, 3)

        # Sample shape parameters and run the SMPL model on each sample.
        pred_beta_sample = pred_beta_mean[b:b+1] + epsilon[:, 138:148]*pred_beta_sigma[b:b+1]
        pred_output_sample = smpl.forward(betas=pred_beta_sample, thetas_rotmat=pred_pose_sample_rotmat)
        pred_vertices_sample = pred_output_sample.vertices
        pred_joints_sample = pred_output_sample.joints

        # Convert the weak-perspective camera (s, tx, ty) to a translation vector.
        pred_cam_sample = pred_cam_mean[b:b+1]
        pred_cam_t_sample = torch.stack([pred_cam_sample[:, 1],
                                  pred_cam_sample[:, 2],
                                  focal_length/(pred_cam_sample[:, 0] + 1e-9)], dim=-1)

        # Project the sampled 3D joints to 2D (camera center at 112 for a 224x224 crop).
        pred_keypoints_2d_sample = perspective_projection(pred_joints_sample,
                                                   rotation=torch.eye(3, device=device).unsqueeze(0).expand(num_samples, -1, -1),
                                                   translation=pred_cam_t_sample,
                                                   focal_length=focal_length,
                                                   camera_center=torch.ones(num_samples, 2, device=device) * (112))
        # Epistemic uncertainty: std of the projected keypoints across the samples.
        pred_keypoints_2d_sample_numpy = pred_keypoints_2d_sample.detach().cpu().numpy().reshape(-1, 46)
        model_kp_sigma[b] = np.std(pred_keypoints_2d_sample_numpy, axis=0)
    return model_kp_sigma
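
If a single per-image score is needed, one simple (hypothetical) reduction mirrors the trace used for the aleatoric case, summing the per-coordinate variances:

# Hypothetical per-image epistemic score: sum of per-coordinate variances
# of the Monte Carlo samples (analogous to the trace of a diagonal covariance).
Ue = (model_kp_sigma**2).sum(axis=1)   # [N], one scalar per image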
zhangy76 commented 10 months ago

Marked the issue as completed due to inactivity. Please feel free to reopen it if you have any questions.