Totoro97 / f2-nerf

Fast neural radiance field training with free camera trajectories
https://totoro97.github.io/projects/f2-nerf/
Apache License 2.0

Script/Explanation for figure 4 #32

Closed: kwea123 closed this issue 1 year ago

kwea123 commented 1 year ago

Do you have any script that can visualize something like figure 4, or any detailed explanations?

I wrote something that follows your formula, but I get totally different results. I use a toy example with two 1D cameras and F(x) = (C1(x), C2(x)) as the warping function. The warping looks roughly correct for theta=30,

[Screenshot 2023-05-08 15:44:17]

but totally different for theta=180

[Screenshot 2023-05-08 15:41:01]

Should I increase the number of cameras and really do PCA? I feel like it should work with only two cameras.
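
For concreteness, the toy setup is roughly this (a minimal 2D sketch, not my exact script): two 1D pinhole cameras on a circle around the origin, each projecting a 2D world point to a single image coordinate u = x / z, and the warp is just the concatenation of the two projections.

import numpy as np
import matplotlib.pyplot as plt

theta_deg = 30.0   # angle between the two viewing directions
cam_dis = 3.0      # distance of each camera from the origin
N = 21             # grid resolution

def make_cam(angle):
    # 1D pinhole camera at angular position `angle` on a circle of radius cam_dis,
    # looking at the origin; returns its world-to-camera rotation (rows) and position.
    pos = cam_dis * np.array([np.cos(angle), np.sin(angle)])
    z_axis = -pos / np.linalg.norm(pos)          # viewing direction, towards the origin
    x_axis = np.array([-z_axis[1], z_axis[0]])   # in-plane axis orthogonal to z
    return np.stack([x_axis, z_axis]), pos

def project(pts, rot_w2c, pos):
    # Perspective projection of 2D world points to the 1D image coordinate u = x / z.
    p_cam = (pts - pos) @ rot_w2c.T
    return p_cam[:, 0] / p_cam[:, 1]

def warp(pts, cams):
    # F(x) = (C1(x), C2(x)): concatenate the two 1D projections.
    return np.stack([project(pts, rot, pos) for rot, pos in cams], axis=-1)

half = np.deg2rad(theta_deg) / 2.
cams = [make_cam(half), make_cam(-half)]

s = np.linspace(-1., 1., N)
xx, yy = np.meshgrid(s, s)
grid = np.stack([xx, yy], -1).reshape(-1, 2)
warped = warp(grid, cams)

plt.scatter(warped[:, 0], warped[:, 1], s=2)
plt.gca().set_aspect('equal')
plt.show()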

Totoro97 commented 1 year ago

Hello, thank you so much for your interest in this project! Below is my script for the warp-space visualization. We conduct the warping in 3D space, and the visualized points are the warped results of points sampled from a plane. Sorry for the unclear description in the paper.

import numpy as np
import trimesh
from scipy.spatial.transform import Rotation as R
from sklearn.decomposition import PCA

# Define simple cameras
n_cams = 4
cam_dis = 3.0
N = 6
angle_degree = 89.9

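# Generate world-to-camera matrices for n_cams cameras, all at distance cam_dis from the
# origin: a camera on the x-axis is rotated by angle_degree around evenly spaced axes in
# the yz-plane, and each camera's z-axis is aligned with its position vector.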
def gen_cam_w2c():
    delta = angle_degree / 180. * np.pi
    theta = np.linspace(0., 2. * np.pi * (n_cams - 1) / n_cams, n_cams)
    axis = np.stack([
        np.zeros([n_cams]),
        np.cos(theta),
        np.sin(theta)
    ], axis=-1)

    rot = R.from_rotvec(axis * delta)

    cam_pos = np.zeros([n_cams, 3])
    cam_pos[:, 0] = cam_dis
    cam_pos = rot.apply(cam_pos)
    pcd = trimesh.PointCloud(cam_pos)
    pcd.export('./cam_pos.ply')

    z_vec = cam_pos
    y_vec = np.zeros([n_cams, 3])
    y_vec[:, 2] = 1.

    x_vec = np.cross(y_vec, z_vec)
    y_vec = np.cross(z_vec, x_vec)
    y_vec = y_vec / np.linalg.norm(y_vec, ord=2, axis=-1, keepdims=True)
    x_vec = x_vec / np.linalg.norm(x_vec, ord=2, axis=-1, keepdims=True)
    z_vec = z_vec / np.linalg.norm(z_vec, ord=2, axis=-1, keepdims=True)
    rot_w2c = np.stack([x_vec, y_vec, z_vec], axis=1)
    rot_c2w = np.linalg.inv(rot_w2c)
    c2w = np.zeros([n_cams, 4, 4])
    c2w[:, :3, :3] = rot_c2w
    c2w[:, :3, 3] = cam_pos
    c2w[:, 3, 3] = 1.

    w2c = np.linalg.inv(c2w)
    return w2c

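# Sample an N x N x N grid of points covering the cube [-1, 1]^3 and export it for reference.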
def gen_pts():
    st = np.linspace(-1., 1., N)
    xx, yy, zz = np.meshgrid(st, st, st)

    pts = np.stack([xx, yy, zz], -1).reshape([-1, 3])
    pcd = trimesh.PointCloud(pts)
    pcd.export('./volume.ply')
    return pts

def hello():
    w2c = gen_cam_w2c()  # [n_cams, 4, 4]
    pts = gen_pts()      # [n_pts, 3]
    n_cams = len(w2c)
    n_pts = len(pts)

    # Points at camera spaces
    pts_cam = w2c[None, :, :3, :3] @ pts[:, None, :, None] + w2c[None, :, :3, 3:]
    pts_cam = pts_cam[..., 0]  # [n_pts, n_cams, 3]

    dcam_dxyz = w2c[None, :, :3, :3]
    du_dcam_xy = np.concatenate([pts_cam[:, :, 2:3], np.zeros_like(pts_cam[:, :, 2:3])], 2)
    dv_dcam_xy = np.concatenate([np.zeros_like(pts_cam[:, :, 2:3]), pts_cam[:, :, 2:3]], 2)
    duv_dcam_xy = np.stack([du_dcam_xy, dv_dcam_xy], 2)
    duv_dcam_z = -pts_cam[:, :, :2] / pts_cam[:, :, 2:3]**2
    duv_dcam = np.concatenate([duv_dcam_xy, duv_dcam_z[..., None]], -1)  # [n_pts, n_cams, 2, 3]

    # Points at image spaces
    pts_uv = pts_cam[:, :, :2] / pts_cam[:, :, 2:3]   # [n_pts, n_cams, 2]

    cat_coords = pts_uv.reshape([n_pts, -1])              # [n_pts, n_cams * 2]
    duv_dxyz = duv_dcam @ dcam_dxyz

    # Jacobian matrices from world space to concatenated image space
    dcoords_dxyz = duv_dxyz.reshape([n_pts, n_cams * 2, -1])  # [n_pts, n_cams * 2, 3]

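    # PCA over the concatenated 2D image coordinates finds the 3 principal directions of the
    # (2 * n_cams)-dimensional projected points; its components define the linear map from the
    # concatenated image space down to the 3D warp space.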
    pca = PCA(n_components=3)
    pca.fit(cat_coords)
    lin = pca.components_

    # Jacobian matrices from world space to warp' space
    dwarp_dxyz = lin[None] @ dcoords_dxyz
    dxyz_dwarp = np.linalg.inv(dwarp_dxyz)

    # Estimated jacobian matrices from warp' space to concatenated image space
    dcoords_dwarp = dcoords_dxyz @ dxyz_dwarp

    # This value is how far we need to move along each axis in the warp space to correspond to a
    # "unit length" in the image space
    expect_warp_axis_scale = 1. / np.max(np.abs(dcoords_dwarp), axis=1)
    mean_axis_scale = expect_warp_axis_scale.mean(0)

    # Rescale the linear mapping matrix, because we want the "unit length" along each axis of the
    # warp space to be approximately equal to the "unit length" in the image space
    lin = lin / mean_axis_scale[:, None]
    plane_cat_coords = cat_coords.reshape([N, N, N, n_cams * 2])[:, :, N // 2, :].reshape(-1, n_cams * 2)
    volume_cat_coords = cat_coords.reshape([N, N, N, n_cams * 2]).reshape(-1, n_cams * 2)

    warped_plane = (lin[None] @ plane_cat_coords[..., None])[..., 0]
    warped_volume = (lin[None] @ volume_cat_coords[..., None])[..., 0]
    colors_a = np.linspace(0.1, 0.7, N)
    colors_b = np.linspace(0.7, 0.1, N)
    aa, bb = np.meshgrid(colors_a, colors_b)
    aa = aa.reshape(-1)
    bb = bb.reshape(-1)
    colors = np.stack([aa, bb, np.ones_like(aa) * 0.6], -1)

    plane_pcd = trimesh.PointCloud(pts.reshape([N, N, N, 3])[:, :, N // 2, :].reshape(-1, 3),  vertex_colors=colors)
    plane_pcd.export('./plane.ply')
    warped_plane_pcd = trimesh.PointCloud(warped_plane, vertex_colors=colors)
    warped_plane_pcd.export('./warped_plane.ply')
    warped_volume_pcd = trimesh.PointCloud(warped_volume)
    warped_volume_pcd.export('./warped_volume.ply')

if __name__ == '__main__':
    hello()
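
Running the script exports cam_pos.ply, volume.ply, plane.ply, warped_plane.ply and warped_volume.ply to the working directory; you can inspect them in any point-cloud viewer (e.g. MeshLab).
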
kwea123 commented 1 year ago

I tried reducing the number of cameras to 2 and compared my result (2D) against yours (3D):

[Screenshot 2023-05-08 21:53:28] [Screenshot 2023-05-08 21:53:51]

I think my implementation is correct, as the shapes are similar. So I have two insights here:

  1. We need more cameras (e.g., 4 as in the paper) to keep the warped space "regular", i.e., a uniform grid stays approximately a uniform grid instead of becoming as distorted as it is here.
  2. The cameras also need to be roughly evenly spaced in 3D; otherwise, as in my 2D example in the comment above, the warped grid becomes very irregular, unlike the result here (above I just randomly offset the cameras to the right, whereas here I make them centered and symmetric). I also noticed that your script generates evenly spaced cameras. In reality I don't think this happens often; have you dug into what the warped grid looks like when the camera positions are more irregular (a quick way to test this is sketched right below)? I think it's interesting.
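
One quick way to test this with your script is to perturb the camera positions inside gen_cam_w2c before the look-at frames are built, e.g. replace cam_pos = rot.apply(cam_pos) with something like (just a sketch; the noise scale is arbitrary):

rng = np.random.default_rng(0)
cam_pos = rot.apply(cam_pos)
# random offsets -> irregular camera layout
cam_pos = cam_pos + rng.normal(scale=0.5, size=cam_pos.shape)
# optionally renormalize so the cameras stay at roughly the same distance from the scene
cam_pos = cam_pos / np.linalg.norm(cam_pos, axis=-1, keepdims=True) * cam_dis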

Thanks for the insightful work!

kwea123 commented 1 year ago

Also, just out of curiosity, why is the config called wanjinyou? :laughing:

Totoro97 commented 1 year ago

Hi, thanks so much for your thoughts! They are very insightful. I have not yet thoroughly looked at the visualizations of the warped volumes for all kinds of camera distributions; you are correct that it would be interesting to see what they look like.

wanjinyou means "万金油/萬金油" (a cure-all, figuratively a jack-of-all-trades). I chose it because I hope the method does not fail at synthesizing novel views when trained on diverse kinds of capture trajectories, as long as SfM runs successfully and the number of training images is not too large. Though there is still plenty of room to improve the synthesized quality 😳

kwea123 commented 1 year ago

I'll leave my script here too: https://gist.github.com/kwea123/6ef313d976c97c43d43d8e710572493c