sony / genwarp

MIT License

Camera Convention #13

mohammadasim98 commented 1 week ago

Hi,

First of all, great work! It is very impressive. However, I couldn't find any information about the camera convention you are using, i.e., OpenCV/COLMAP, OpenGL, etc. I would like to know the exact convention used for the extrinsics, i.e., the relative view matrix.

Thank you,

Best, Asm

j0seo commented 5 days ago

Hello. Thank you for your interest in our work. We are planning an update soon to support [R|t] and K conditioning, which should be more familiar to researchers. In the meantime, if you'd like to implement this, you can temporarily input R, t, and K to the model by modifying our forward warping function here, as shown below. Please note that the code below is unorganized and might need some debugging - sorry for this.

```python
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from torchvision import transforms
from splatting import splatting_function  # forward-splatting package used in genwarp

def forward_warper(images, depths, R, T, K=None, H=512, W=512, device=None):

    def pi_inv(K, x, d):
        # Back-project pixel coordinates x (B, H, W, 2) with depth d (B, H, W)
        # into camera-space 3D points (B, H, W, 3).
        fx, fy, cx, cy = K[:, 0:1, 0:1], K[:, 1:2, 1:2], K[:, 0:1, 2:3], K[:, 1:2, 2:3]
        X_x = d * (x[..., 0] - cx) / fx
        X_y = d * (x[..., 1] - cy) / fy
        X_z = d
        X = torch.stack([X_x, X_y, X_z], dim=-1)
        return X

    def x_2d_coords(h, w, device):
        x_2d = torch.zeros((h, w, 2), device=device)
        for y in range(0, h):
            x_2d[y, :, 1] = y
        for x in range(0, w):
            x_2d[:, x, 0] = x
        return x_2d

    def transpose(R, t, X):
        # Rigid transform X' = R @ X + t applied to points X of shape (B, H, W, 3).
        b, h, w, c = X.shape
        X = rearrange(X, 'b h w c -> b c (h w)')
        X_after_R = R @ X + t[:, :, None]
        X_after_R = rearrange(X_after_R, 'b c (h w) -> b h w c', h=h)
        return X_after_R

    if K is None:
        focal = (5.8269e+02, 5.8269e+02)
        K = torch.tensor([
            [focal[0], 0., W / 2],
            [0., focal[1], H / 2],
            [0., 0., 1.]], device=device)
        K = K[None, ...]

    # Normalize input types and shapes to batched tensors on `device`.
    if isinstance(depths, np.ndarray):
        depths = torch.tensor(depths).to(device)
    if isinstance(R, np.ndarray):
        R = torch.tensor(R).float().to(device)
        T = torch.tensor(T).float().to(device)

    if R.dim() == 2:
        R = R[None, ...]
        T = T[None, ...]

    if isinstance(images, Image.Image):
        images = transforms.functional.to_tensor(images).to(device)

    if images.dim() == 3:
        images = images[None, ...]
        B = 1
    else:
        B = images.shape[0]

    if depths.dim() == 2:
        depths = depths[None, None, ...]
    elif depths.dim() == 3:
        depths = depths.unsqueeze(1)

    # Unproject pixels to 3D camera space, then rotate / translate into the target view.
    coords = x_2d_coords(H, W, device=device)[None, ...].repeat(B, 1, 1, 1)
    coords_3d = pi_inv(K, coords, depths[:, 0, ...])
    coords_world = transpose(R, T, coords_3d)
    coords_world = rearrange(coords_world, 'b h w c -> b c (h w)')

    # Project the transformed points back to the image plane.
    proj_coords_3d = K[:, :3, :3] @ coords_world
    proj_coords_3d = rearrange(proj_coords_3d, 'b c (h w) -> b h w c', h=H, w=W)
    proj_coords = proj_coords_3d[..., :2] / (proj_coords_3d[..., 2:3] + 1e-6)

    # Mask invalid sources (zero depth, points behind the camera) by pushing
    # their targets far outside the image so splatting discards them.
    mask = depths[:, 0, ...] == 0
    proj_coords[mask] = -1000000 if proj_coords.dtype == torch.float32 else -1e+4
    back_mask = proj_coords_3d[..., 2:3] <= 0
    back_mask = back_mask.repeat(1, 1, 1, 2)
    proj_coords[back_mask] = -1000000 if proj_coords.dtype == torch.float32 else -1e+4

    # Flow from source to target pixel positions, plus the warped depth.
    new_z = proj_coords_3d[..., 2:3].permute(0, 3, 1, 2)  # B H W 1 -> B 1 H W
    flow = proj_coords - coords
    flow = flow.permute(0, 3, 1, 2)  # B H W 2 -> B 2 H W

    # Depth-based importance weights: nearer points get exponentially larger
    # weights, so they dominate when several pixels splat onto the same target.
    alpha = 0.5
    importance = alpha / new_z
    importance_min = importance.amin((1, 2, 3), keepdim=True)
    importance_max = importance.amax((1, 2, 3), keepdim=True)
    importance = (importance - importance_min) / (importance_max - importance_min + 1e-6) * 10 - 10
    importance = importance.exp()
    input_data = torch.cat([importance * images, importance * new_z, importance], 1)
    output_data = splatting_function("summation", input_data, flow)

    # Normalize by the accumulated weights, then split image and depth channels.
    renderings = output_data[:, :-1, :, :] / (output_data[:, -1:, :, :] + 1e-6)
    renderings, warped_depths = renderings[:, :-1, :, :], renderings[:, -1:, :, :]
    masks = (renderings == 0.).all(dim=1).int()  # pixels that received no splat

    if torch.isnan(renderings).sum() > 0:
        print("NaNs in renderings")
        renderings = torch.zeros_like(renderings)
        warped_depths = torch.zeros_like(warped_depths)

    output = dict(
        warped=renderings,
        mask=masks.float(),  # occlusion / hole mask of the warped image
        correspondence=None
    )

    return output
```
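
For reference, a minimal call could look like the sketch below. The file names and pose values are placeholders, not from the repo. Note that `pi_inv` and the `z <= 0` culling above imply an OpenCV/COLMAP-style camera (x right, y down, +z forward, pixel origin at the top-left), with `R` and `T` mapping source-camera points into the target camera.

```python
# Minimal usage sketch -- all paths and pose values below are hypothetical.
import numpy as np
import torch
from PIL import Image

device = 'cuda'
image = Image.open('source.png').resize((512, 512))     # placeholder source image
depth = np.load('source_depth.npy').astype(np.float32)  # placeholder (512, 512) depth map

# Relative pose of the target camera: identity rotation, small shift along +x.
R = np.eye(3, dtype=np.float32)
T = np.array([0.1, 0.0, 0.0], dtype=np.float32)

# Pinhole intrinsics with (fx, fy, cx, cy) = (512, 512, 256, 256).
K = torch.tensor([[512., 0., 256.],
                  [0., 512., 256.],
                  [0., 0., 1.]], device=device)[None]

out = forward_warper(image, depth, R, T, K=K, H=512, W=512, device=device)
warped = out['warped']  # (1, 3, 512, 512) warped image
mask = out['mask']      # (1, 512, 512) occlusion / hole mask
```

Passing `R` and `T` as NumPy arrays lets the function move them to the right device itself; torch tensor inputs are expected to already be on `device`.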