nileshkulkarni / csm

Code release for "Canonical Surface Mapping via Geometric Cycle Consistency"
https://nileshkulkarni.github.io/csm/

GCC loss implementation #22

Open meijie0401 opened 3 years ago

meijie0401 commented 3 years ago

I want to implement GCC loss in my project and I'm trying to figure out the details of your implementation.

Question-0: The branch that predicts UV for each foreground pixel (used for the GCC loss) is only used during training, i.e. it only indirectly helps train the 'camera pose' and 'deformation' branches and is not used during inference. Is my understanding correct?

Below are the relevant pieces of your code, along with my understanding of (and questions about) the GCC loss implementation.

Step-1, set the ground truth 'self.codes_gt['xy_map']' with the linspace function (easy to understand~):

```python
def get_sample_grid(img_size):
    x = torch.linspace(-1, 1, img_size[1]).view(1, -1).repeat(img_size[0], 1)
    y = torch.linspace(-1, 1, img_size[0]).view(-1, 1).repeat(1, img_size[1])
    grid = torch.cat((x.unsqueeze(2), y.unsqueeze(2)), 2)
    grid.unsqueeze(0)
    return grid
```

```python
self.grid = cub_parse.get_sample_grid(self.upsample_img_size).repeat(1, 1, 1, 1).to(self.device)

grid = self.grid.repeat(b_size, 1, 1, 1)
self.codes_gt['xy_map'] = grid
```
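As a sanity check on my side (not from your repo), printing a small grid suggests 'xy_map' is simply the normalized pixel-coordinate grid in [-1, 1]; please correct me if I'm misreading it:

```python
import torch

# my own quick check: a tiny 3x4 grid from get_sample_grid
grid = get_sample_grid((3, 4))   # H=3, W=4
print(grid.shape)                # torch.Size([3, 4, 2])
print(grid[..., 0])              # x per row: [-1.0000, -0.3333, 0.3333, 1.0000]
print(grid[..., 1])              # y per column: [-1.0, 0.0, 1.0]
```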

Step-2, map the predicted UV of each pixel to a 3D location on the template surface, getting 'points3d':

```python
def project_uv_to_3d(uv2points, uv_map):
    B = uv_map.size(0)
    H = uv_map.size(1)
    W = uv_map.size(2)
    uv_map_flatten = uv_map.view(-1, 2)
    points3d = uv2points.forward(uv_map_flatten)
    points3d = points3d.view(B, H * W, 3)
    return points3d
```

```python
self.uv2points = cub_parse.UVTo3D(self.mean_shape)
points3d = geom_utils.project_uv_to_3d(self.uv2points, codes_pred['uv_map'])
```
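For context on Question-1 below, my rough mental model is that (u, v) parameterizes the surface of a sphere, and that 'UVTo3D' then carries the sphere point onto the mean shape. This is only my assumption for illustration, not your actual implementation:

```python
import math
import torch

def uv_to_sphere(uv):
    # my assumption, NOT the repo's UVTo3D: interpret (u, v) in [0, 1]^2 as
    # spherical angles and return the corresponding point on the unit sphere;
    # I imagine the real mapping then transfers this point onto the template
    # mesh surface (e.g. via faces / barycentric coordinates).
    u, v = uv[..., 0], uv[..., 1]
    theta = 2 * math.pi * u      # azimuth
    phi = math.pi * v            # polar angle
    x = torch.sin(phi) * torch.cos(theta)
    y = torch.sin(phi) * torch.sin(theta)
    z = torch.cos(phi)
    return torch.stack([x, y, z], dim=-1)
```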

Question-1: Why do you map UV to 3D with 'UVTo3D' and then wrap it in 'project_uv_to_3d'? What is the graphics formula/theory behind these two functions?

Step-3, orthographically project 'points3d' onto the image plane, getting 'codes_pred['project_points']':

```python
def project_3d_to_image(points3d, cam, offset_z):
    projected_points = orthographic_proj_withz(points3d, cam, offset_z)
    return projected_points

def orthographic_proj_withz(X, cam, offset_z=0.):
    """
    X: B x N x 3
    cam: B x 7: [sc, tx, ty, quaternions]
    Orth preserving the z.
    """
    quat = cam[:, -4:]
    X_rot = quat_rotate(X, quat)
    scale = cam[:, 0].contiguous().view(-1, 1, 1)
    trans = cam[:, 1:3].contiguous().view(cam.size(0), 1, -1)

    proj = scale * X_rot

    proj_xy = proj[:, :, :2] + trans
    proj_z = proj[:, :, 2, None] + offset_z

    return torch.cat((proj_xy, proj_z), 2)
```

```python
codes_pred['project_points_cam_pred'] = geom_utils.project_3d_to_image(points3d, codes_pred['cam'], self.offset_z)
codes_pred['project_points_cam_pred'] = codes_pred['project_points_cam_pred'][..., 0:2].view(self.codes_gt['xy_map'].size())
codes_pred['project_points'] = codes_pred['project_points_cam_pred']
```
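To check my reading of Step-3: for cam = [sc, tx, ty, quat], each projected xy should be sc * R(X).xy + (tx, ty). A made-up example with an identity quaternion (assuming the w-first convention), worked by hand:

```python
import torch

# made-up numbers, only to illustrate my reading of orthographic_proj_withz
X = torch.tensor([[[0.5, -0.5, 1.0]]])                      # B=1, N=1, 3
cam = torch.tensor([[2.0, 0.1, -0.2, 1.0, 0.0, 0.0, 0.0]])  # [sc, tx, ty, identity quat]
proj = orthographic_proj_withz(X, cam, offset_z=5.0)
# by hand: xy = 2.0 * (0.5, -0.5) + (0.1, -0.2) = (1.1, -1.2)
#          z  = 2.0 * 1.0 + 5.0 = 7.0
print(proj)   # tensor([[[ 1.1000, -1.2000,  7.0000]]])
```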

Step-4: GCC L2 loss between 'codes_pred['project_points']' and 'self.codes_gt['xy_map']':

```python
# Reprojection Loss
project_points = codes_pred['project_points']
if opts.ignore_mask_gcc:
    reproject_loss = reproject_loss_l2(project_points, codes_gt['xy_map'], seg_mask * 0 + 1)
else:
    reproject_loss = reproject_loss_l2(project_points, codes_gt['xy_map'], seg_mask)
```
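This is what I assume 'reproject_loss_l2' computes: a mask-weighted L2 between the reprojected points and the grid. My own guess, not copied from your repo:

```python
import torch

def reproject_loss_l2_guess(project_points, xy_map, mask):
    # my guess at the loss: per-pixel squared distance between the reprojected
    # xy and the ground-truth grid, averaged over the foreground mask
    sq_dist = ((project_points - xy_map) ** 2).sum(dim=-1)   # B x H x W
    mask = mask.view_as(sq_dist)
    return (sq_dist * mask).sum() / mask.sum().clamp(min=1)
```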

Question-2: Why does 'codes_pred['project_points']', obtained from 'orthographic_proj_withz', range from 0 to 255 (e.g. when the input image size is 256x256)? I think 'points3d' is already in this range, but how does Step-2 make this happen?
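(For reference, this is the conversion between normalized and pixel coordinates I have in mind when talking about ranges; it is the standard formula, not something from your repo:)

```python
# normalized coordinate in [-1, 1]  <->  pixel coordinate in [0, img_size - 1]
def norm_to_pixel(x_norm, img_size=256):
    return (x_norm + 1.0) * 0.5 * (img_size - 1)

def pixel_to_norm(x_pix, img_size=256):
    return x_pix / (img_size - 1) * 2.0 - 1.0
```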