mohammadasim98 opened 1 week ago
Hello. Thank you for your interest in our work. We are planning an update soon that will support R|t and K conditioning, which should be more familiar to researchers. In the meantime, if you'd like to implement this yourself, you can temporarily feed R, t, and K into the model by modifying our forward warping function here, as shown below. Please note that the code below is unorganized and may need some debugging - sorry about that.
import numpy as np
import torch
import wandb
from einops import rearrange
from PIL import Image
from torchvision import transforms

from splatting import splatting_function  # adjust to however the repo provides splatting_function


def forward_warper(images, depths, R, T, K=None, H=512, W=512, device=None):
    def pi_inv(K, x, d):
        # Back-project pixel coordinates x (B, H, W, 2) with depth d (B, H, W)
        # into camera-space 3D points using intrinsics K (B, 3, 3).
        fx, fy = K[:, 0:1, 0:1], K[:, 1:2, 1:2]
        cx, cy = K[:, 0:1, 2:3], K[:, 1:2, 2:3]
        X_x = d * (x[..., 0] - cx) / fx
        X_y = d * (x[..., 1] - cy) / fy
        X_z = d
        X = torch.stack([X_x, X_y, X_z], dim=-1)
        return X

    def x_2d_coords(h, w, device):
        # Pixel grid of shape (h, w, 2) with (x, y) ordering.
        x_2d = torch.zeros((h, w, 2), device=device)
        for y in range(0, h):
            x_2d[y, :, 1] = y
        for x in range(0, w):
            x_2d[:, x, 0] = x
        return x_2d

    def transpose(R, t, X):
        # Apply the rigid transform (R, t) to a point map X of shape (B, H, W, 3).
        b, h, w, c = X.shape
        X = rearrange(X, 'b h w c -> b c (h w)')
        X_after_R = R @ X + t[:, :, None]
        X_after_R = rearrange(X_after_R, 'b c (h w) -> b h w c', h=h)
        return X_after_R

    if K is None:
        # Fallback intrinsics: fixed focal length, principal point at the image center.
        focal = (5.8269e+02, 5.8269e+02)
        K = torch.tensor([
            [focal[0], 0., W / 2],
            [0., focal[1], H / 2],
            [0., 0., 1.]], device=device)
        K = K[None, ...]

    # Convert inputs to batched torch tensors on the target device.
    if isinstance(depths, np.ndarray):
        depths = torch.tensor(depths).to(device)
    if isinstance(R, np.ndarray):
        R = torch.tensor(R).float().to(device)
        T = torch.tensor(T).float().to(device)
    if R.dim() == 2:
        R = R[None, ...]
        T = T[None, ...]
    if isinstance(images, Image.Image):
        images = transforms.functional.to_tensor(images).to(device)
    if images.dim() == 3:
        images = images[None, ...]
        B = 1
    else:
        B = images.shape[0]
    if depths.dim() == 2:
        depths = depths[None, None, ...]
    elif depths.dim() == 3:
        depths = depths.unsqueeze(1)

    # Unproject to 3D, then rotate / translate into the target view.
    # with torch.autocast(device, enabled=False):
    coords = x_2d_coords(H, W, device=device)[None, ...].repeat(B, 1, 1, 1)
    coords_3d = pi_inv(K, coords, depths[:, 0, ...])
    coords_world = transpose(R, T, coords_3d)
    coords_world = coords_world.reshape((-1, H, W, 3))
    coords_world = rearrange(coords_world, 'b h w c -> b c (h w)')

    # Project back onto the image plane.
    proj_coords_3d = K[:, :3, :3] @ coords_world
    proj_coords_3d = rearrange(proj_coords_3d, 'b c (h w) -> b h w c', h=H, w=W)
    proj_coords = proj_coords_3d[..., :2] / (proj_coords_3d[..., 2:3] + 1e-6)

    # Mask out invalid (zero-depth) pixels and points behind the camera.
    mask = depths[:, 0, ...] == 0
    proj_coords[mask] = -1000000 if proj_coords.dtype == torch.float32 else -1e+4
    back_mask = proj_coords_3d[..., 2:3] <= 0
    back_mask = back_mask.repeat(1, 1, 1, 2)
    proj_coords[back_mask] = -1000000 if proj_coords.dtype == torch.float32 else -1e+4

    # Flow from source pixels to their projections in the target view.
    new_z = proj_coords_3d[..., 2:3].permute(0, 3, 1, 2)  # B H W 1 -> B 1 H W
    flow = proj_coords - coords
    flow = flow.permute(0, 3, 1, 2)  # B H W 2 -> B 2 H W

    # Depth-based importance weights (closer points win) for the splatting step.
    alpha = 0.5
    importance = alpha / new_z
    importance_min = importance.amin((1, 2, 3), keepdim=True)
    importance_max = importance.amax((1, 2, 3), keepdim=True)
    importance = (importance - importance_min) / (importance_max - importance_min + 1e-6) * 10 - 10
    importance = importance.exp()

    # Splat image, depth, and weights, then normalize by the accumulated weights.
    input_data = torch.cat([importance * images, importance * new_z, importance], 1)
    output_data = splatting_function("summation", input_data, flow)
    renderings = output_data[:, :-1, :, :] / (output_data[:, -1:, :, :] + 1e-6)
    renderings, warped_depths = renderings[:, :-1, :, :], renderings[:, -1:, :, :]
    masks = (renderings == 0.).all(dim=1).int()  # hole mask in the warped image (currently unused)

    if torch.isnan(renderings).sum() > 0:
        print("NaNs in renderings")
        wandb.alert(title="NaNs in renderings", text="NaNs in renderings")
        renderings = torch.zeros_like(renderings)
        warped_depths = torch.zeros_like(warped_depths)

    output = dict(
        warped=renderings,
        mask=mask.float(),
        correspondence=None
    )
    return output
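For reference, here is a minimal sketch of how the function above might be called with explicit R, t, and K. The file names, intrinsics, and pose values are placeholders of my own (not from the repo), and the pose convention (world-to-camera vs. camera-to-world, OpenCV vs. OpenGL axes) is exactly what the original question asks about, so treat the R/t below as illustrative only.

    # Hypothetical usage sketch; all paths and numeric values are placeholders.
    import numpy as np
    import torch
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"

    image = Image.open("source_view.png").resize((512, 512))      # source RGB image
    depth = np.load("source_depth.npy").astype(np.float32)        # (512, 512) depth map

    # Relative pose (target <- source); the exact convention is the open question here.
    R_rel = np.eye(3, dtype=np.float32)
    t_rel = np.array([0.1, 0.0, 0.0], dtype=np.float32)

    # Pinhole intrinsics as a (1, 3, 3) batch.
    K = torch.tensor([[500.0,   0.0, 256.0],
                      [  0.0, 500.0, 256.0],
                      [  0.0,   0.0,   1.0]], device=device)[None]

    out = forward_warper(image, depth, R_rel, t_rel, K=K, H=512, W=512, device=device)
    warped = out["warped"]   # (1, 3, 512, 512) forward-warped image
    hole_mask = out["mask"]  # float mask of zero-depth source pixels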
Hi,
First of all, great work! It is very impressive. However, I am missing some information about the camera convention you are using, i.e., OpenCV/COLMAP, OpenGL, etc. I would like to know the exact camera convention used for the extrinsics, i.e., the relative view matrix.
Thank you,
Best, Asm