zju3dv / ENeRF

SIGGRAPH Asia 2022: Code for "Efficient Neural Radiance Fields for Interactive Free-viewpoint Video"
https://zju3dv.github.io/enerf

Doubt with homo_warp #37

Open smontode24 opened 1 year ago

smontode24 commented 1 year ago

Hi,

I was working with your code and, while reviewing how the feature maps are projected into the cost volume, I came across something I don't understand. The projection matrix passed to homo_warp is computed in get_proj_mats, and it goes from the target-view camera coordinates to the source view so that the source features can be interpolated:

import torch

def get_proj_mats(batch, src_scale, tar_scale):
    B, S_V, C, H, W = batch['src_inps'].shape
    src_ext = batch['src_exts']
    src_ixt = batch['src_ixts'].clone()
    src_ixt[:, :, :2] *= src_scale
    src_projs = src_ixt @ src_ext[:, :, :3]

    tar_ext = batch['tar_ext']
    tar_ixt = batch['tar_ixt'].clone()
    tar_ixt[:, :2] *= tar_scale
    tar_projs = tar_ixt @ tar_ext[:, :3]
    # Pad P_tar = K_tar @ [R|t] (3x4) with a [0, 0, 0, 1] row so it becomes an invertible 4x4 matrix
    tar_ones = torch.zeros((B, 1, 4)).to(tar_projs.device)
    tar_ones[:, :, 3] = 1
    tar_projs = torch.cat((tar_projs, tar_ones), dim=1)
    tar_projs_inv = torch.inverse(tar_projs)

    src_projs = src_projs.view(B, S_V, 3, 4)
    tar_projs_inv = tar_projs_inv.view(B, 1, 4, 4)

    proj_mats = src_projs @ tar_projs_inv
    return proj_mats
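For reference, here is a small NumPy sketch (not code from the ENeRF repo) that mirrors get_proj_mats for a single source view, using made-up intrinsics and extrinsics and assuming the extrinsics are world-to-camera [R|t] matrices. It checks what the composite matrix acts on: feeding it the target pixel (u, v) scaled by its depth, i.e. (u·d, v·d, d, 1), gives the same homogeneous source pixel as projecting the underlying 3D point directly. The helpers rotation_z, make_projection and composite_proj are hypothetical; only the 4×4 padding and inversion follow the code above.

```python
import numpy as np

def rotation_z(theta):
    """Hypothetical helper: rotation about the z axis by theta radians."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def make_projection(K, R, t):
    """Hypothetical helper: 3x4 projection P = K @ [R | t] (world -> homogeneous pixels)."""
    return K @ np.hstack([R, t.reshape(3, 1)])

def composite_proj(P_src, P_tar):
    """Mirrors get_proj_mats for one view: pad P_tar to 4x4, invert, left-multiply by P_src."""
    P_tar_4x4 = np.vstack([P_tar, [0.0, 0.0, 0.0, 1.0]])
    return P_src @ np.linalg.inv(P_tar_4x4)

# Made-up cameras (assumption: world-to-camera extrinsics)
K_tar = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
K_src = np.array([[450.0, 0.0, 300.0], [0.0, 450.0, 220.0], [0.0, 0.0, 1.0]])
P_tar = make_projection(K_tar, rotation_z(0.1), np.array([0.1, -0.2, 2.0]))
P_src = make_projection(K_src, rotation_z(-0.2), np.array([0.3, 0.1, 2.5]))
proj_mat = composite_proj(P_src, P_tar)

# Project a homogeneous world point into the target view: pixel (u, v) at z-depth d
X = np.array([0.2, -0.1, 1.5, 1.0])
u_d, v_d, d = P_tar @ X
u, v = u_d / d, v_d / d

# proj_mat applied to the depth-scaled target pixel equals projecting X into the source view
lhs = proj_mat @ np.array([u * d, v * d, d, 1.0])
rhs = P_src @ X
print(np.allclose(lhs, rhs))  # True
```

The check holds because the padded 4×4 matrix sends (X, Y, Z, 1) to (u·d, v·d, d, 1), so its inverse undoes exactly that step before the source projection is applied.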

However, I don't understand which coordinates are used when the grid is projected into the source image. Only pixel indices seem to be used, and they are projected into the source image by slicing the projection matrix into a rotation part and a translation part, even though the matrix also includes the intrinsics:

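import torch
import torch.nn.functional as F
from kornia.utils import create_meshgrid  # assumed source of create_meshgrid

def homo_warp(src_feat, proj_mat, depth_values, batch):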
    B, D, H_T, W_T = depth_values.shape
    C, H_S, W_S = src_feat.shape[1:]
    device = src_feat.device

    R = proj_mat[:, :, :3] # (B, 3, 3)
    T = proj_mat[:, :, 3:] # (B, 3, 1)
    # create grid from the ref frame
    ref_grid = create_meshgrid(H_T, W_T, normalized_coordinates=False,
                               device=device) # (1, H, W, 2)
    ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W)
    ref_grid = ref_grid.reshape(1, 2, H_T*W_T) # (1, 2, H*W)
    ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W)
    ref_grid = torch.cat((ref_grid, torch.ones_like(ref_grid[:,:1])), 1) # (B, 3, H*W)
    ref_grid_d = ref_grid.repeat(1, 1, D) # (B, 3, D*H*W)
    src_grid_d = R @ ref_grid_d + T/depth_values.view(B, 1, D*H_T*W_T)
    del ref_grid_d, ref_grid, proj_mat, R, T, depth_values # release (GPU) memory

    # project negative depth pixels to somewhere outside the image
    # negative_depth_mask = src_grid_d[:, 2:] <= 1e-7
    # src_grid_d[:, 0:1][negative_depth_mask] = W
    # src_grid_d[:, 1:2][negative_depth_mask] = H
    # src_grid_d[:, 2:3][negative_depth_mask] = 1

    src_grid = src_grid_d[:, :2] / torch.clamp_min(src_grid_d[:, 2:], 1e-6) # divide by depth (B, 2, D*H*W)
    # del src_grid_d
    src_grid[:, 0] = (src_grid[:, 0])/((W_S - 1) / 2) - 1 # scale to -1~1
    src_grid[:, 1] = (src_grid[:, 1])/((H_S - 1) / 2) - 1 # scale to -1~1
    src_grid = src_grid.permute(0, 2, 1) # (B, D*H*W, 2)
    src_grid = src_grid.view(B, D, H_T*W_T, 2)

    warped_src_feat = F.grid_sample(src_feat, src_grid,
                                    mode='bilinear', padding_mode='zeros',
                                    align_corners=True) # (B, C, D, H*W)
    warped_src_feat = warped_src_feat.view(B, C, D, H_T, W_T)
    src_grid = src_grid.view(B, D, H_T, W_T, 2)
    # Debugging hook: stop in ipdb if any warped feature is NaN
    if torch.isnan(warped_src_feat).any():
        __import__('ipdb').set_trace()
    return warped_src_feat, src_grid
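To make the R/T split concrete: for a 3×4 matrix M = [R | T], M·(u·d, v·d, d, 1)ᵀ = d·R·(u, v, 1)ᵀ + T, so dividing by d gives exactly R·(u, v, 1)ᵀ + T/d, which is the expression on the src_grid_d line. The perspective divide removes that common 1/d factor, so, assuming depth_values are z-depths in the target camera, the grid can stay in raw target pixel indices while the intrinsics live inside R and T. The NumPy sketch below (hypothetical helper warp_pixel, arbitrary matrix values, not code from the repo) checks this identity for one pixel, then applies the same [-1, 1] normalisation that homo_warp uses with align_corners=True.

```python
import numpy as np

def warp_pixel(proj_mat, u, v, d):
    """Hypothetical per-pixel version of homo_warp's projection step.

    (u, v) are raw target-view pixel indices and d is a depth hypothesis.
    Returns homogeneous source-view coordinates before the perspective divide.
    """
    R, T = proj_mat[:, :3], proj_mat[:, 3]  # same split as homo_warp
    return R @ np.array([u, v, 1.0]) + T / d

# Arbitrary stand-in for one 3x4 proj_mat (intrinsics already folded in)
proj_mat = np.array([[1.0, 0.02, 5.0, 10.0],
                     [-0.01, 1.0, 3.0, -4.0],
                     [1e-4, 2e-4, 1.0, 0.05]])
u, v, d = 100.0, 50.0, 2.0

src_h = warp_pixel(proj_mat, u, v, d)
# Same point via the full matrix applied to the depth-scaled pixel (u*d, v*d, d, 1), divided by d
full = proj_mat @ np.array([u * d, v * d, d, 1.0]) / d
print(np.allclose(src_h, full))  # True

# Perspective divide: the shared 1/d factor cancels here
x_pix, y_pix = src_h[:2] / src_h[2]

# Normalisation used before grid_sample(align_corners=True): pixel 0 -> -1, pixel W_S - 1 -> 1
H_S, W_S = 480, 640
x_norm = x_pix / ((W_S - 1) / 2) - 1
y_norm = y_pix / ((H_S - 1) / 2) - 1
corners = np.array([0.0, W_S - 1]) / ((W_S - 1) / 2) - 1  # first and last column map to -1 and 1
```

The same reasoning explains why the slicing works even though proj_mat is not a pure rigid transform: R and T are simply the first three columns and the last column of the full 3×4 matrix.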

Could you explain how the grid coordinates are projected into the source image, and in which coordinate system the grid is defined?

Thanks in advance, Sergio