hbb1 / 2d-gaussian-splatting

[SIGGRAPH'24] 2D Gaussian Splatting for Geometrically Accurate Radiance Fields
https://surfsplatting.github.io
Other
1.56k stars 75 forks source link

Poor results on BMVS dataset #58

Open aaryapatel007 opened 4 weeks ago

aaryapatel007 commented 4 weeks ago

Hi @hbb1,

I tested your model on the BMVS dataset. Unfortunately, I wasn't able to get good results. I used the code provided in #5 and #44 to convert IDR-format data to COLMAP format. I am not sure what I did wrong on my side. Can you help me figure it out?

My idr_to_colmap conversion script:

import os
import numpy as np
import json
import cv2
import sys
from pathlib import Path
from argparse import ArgumentParser
import trimesh

dir_path = Path(os.path.dirname(os.path.realpath(__file__))).parents[2]
sys.path.append(dir_path.__str__())

from convert_data_to_json import export_to_json  # NOQA

from database import COLMAPDatabase  # NOQA
from read_write_model import read_model, rotmat2qvec  # NOQA

def load_K_Rt_from_P(filename, P=None):
    """Decompose a projection matrix into an intrinsics matrix and a pose matrix.

    This function is borrowed from IDR: https://github.com/lioryariv/idr

    Args:
        filename: path of a text file holding the projection matrix, one row per
            line with space-separated values. Only read when ``P`` is None. If the
            file has exactly four lines, the first line is skipped.
        P: optional projection matrix passed straight to
            ``cv2.decomposeProjectionMatrix``; when given, ``filename`` is ignored.

    Returns:
        A tuple ``(intrinsics, pose)`` of 4x4 arrays. ``intrinsics`` holds the
        calibration matrix (normalised so that K[2, 2] == 1) in its upper-left
        3x3 block. ``pose`` holds the transposed rotation in its upper-left 3x3
        block and the de-homogenised translation vector from the decomposition in
        its last column (the caller in this file treats it as camera-to-world).
    """
    if P is None:
        # Use a context manager so the file handle is closed deterministically.
        with open(filename) as fp:
            lines = fp.read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    out = cv2.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    # Normalise the calibration matrix so its bottom-right entry is 1.
    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    # t is homogeneous (4x1); divide by the last component to get a 3-vector.
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose

def create_init_files(pinhole_dict_file, db_file, out_dir):
    """Write cameras.txt, images.txt and an empty points3D.txt into ``out_dir``.

    Partially adapted from
    https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/run_colmap_posed.py

    Args:
        pinhole_dict_file: JSON file mapping image name to
            ``[w, h, fx, fy, cx, cy, qw, qx, qy, qz, tx, ty, tz]``.
        db_file: COLMAP database; its ``images`` table supplies the image id
            for every image name (column 0 is read as id, column 1 as name).
        out_dir: output directory, created if it does not exist.

    Raises:
        ValueError: if the pinhole dict contains no cameras.
        KeyError: if the database lists an image name that is missing from the
            pinhole dict.
    """
    # makedirs also handles nested paths and an already existing directory.
    os.makedirs(out_dir, exist_ok=True)

    with open(pinhole_dict_file) as fp:
        pinhole_dict = json.load(fp)
    if not pinhole_dict:
        raise ValueError(f"No cameras found in {pinhole_dict_file}")

    # The id placeholders are kept as literal '{camera_id}' / '{image_id}' in the
    # first formatting pass and filled in once the database ids are known.
    # NOTE(review): the camera is written as RADIAL with a single focal length
    # (fx) and zero distortion; the fy entry is ignored -- confirm this is intended.
    cameras_line_template = '{camera_id} RADIAL {width} {height} {f} {cx} {cy} {k1} {k2}\n'
    images_line_template = '{image_id} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {camera_id} {image_name}\n\n'

    template = {}
    cam_line = None
    for img_name, params in pinhole_dict.items():
        # params layout: w, h, fx, fy, cx, cy, qvec (4 values), tvec (3 values)
        w, h, fx = params[0], params[1], params[2]
        cx, cy = params[4], params[5]
        qvec = params[6:10]
        tvec = params[10:13]

        cam_line = cameras_line_template.format(
            camera_id="{camera_id}", width=w, height=h, f=fx, cx=cx, cy=cy, k1=0, k2=0)
        img_line = images_line_template.format(
            image_id="{image_id}", qw=qvec[0], qx=qvec[1], qy=qvec[2], qz=qvec[3],
            tx=tvec[0], ty=tvec[1], tz=tvec[2], camera_id="{camera_id}",
            image_name=img_name)
        template[img_name] = (cam_line, img_line)

    # Read the image-name -> image-id mapping and always release the connection.
    db = COLMAPDatabase.connect(db_file)
    try:
        img_name2id_dict = {row[1]: row[0] for row in db.execute("SELECT * FROM images")}
    finally:
        db.close()

    # A single camera (id 1) is shared by every image; its intrinsics are taken
    # from the last entry of the pinhole dict.
    cameras_txt_lines = [cam_line.format(camera_id=1)]
    images_txt_lines = [
        template[img_name][1].format(image_id=img_id, camera_id=1)
        for img_name, img_id in img_name2id_dict.items()
    ]

    with open(os.path.join(out_dir, 'cameras.txt'), 'w') as fp:
        fp.writelines(cameras_txt_lines)

    with open(os.path.join(out_dir, 'images.txt'), 'w') as fp:
        fp.writelines(images_txt_lines)
        fp.write('\n')

    # create an empty points3D.txt
    with open(os.path.join(out_dir, 'points3D.txt'), 'w'):
        pass

def convert_cam_dict_to_pinhole_dict(camera_dict, pinhole_dict_file, img_names, h=1200, w=1600):
    """Convert per-image projection matrices to pinhole parameters and save them as JSON.

    Partially adapted from
    https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/run_colmap_posed.py

    Args:
        camera_dict: mapping (e.g. the object returned by ``np.load`` on
            cameras.npz) with ``'world_mat_<i>'`` and ``'scale_mat_<i>'`` arrays
            for i = 0 .. N-1.
        pinhole_dict_file: path of the JSON file to write.
        img_names: image file names; ``img_names[i]`` is paired with camera ``i``,
            so the list must be ordered like the camera indices.
        h: image height in pixels stored for every camera (default 1200).
        w: image width in pixels stored for every camera (default 1600).

    Raises:
        ValueError: if there are fewer image names than cameras.
    """
    print('Writing pinhole_dict to: ', pinhole_dict_file)

    # Count the cameras from the 'world_mat_<i>' keys instead of assuming a
    # fixed number of entries per camera in the archive.
    num_cams = sum(
        1 for key in camera_dict
        if key.startswith('world_mat_') and not key.startswith('world_mat_inv_'))
    if len(img_names) < num_cams:
        raise ValueError(
            f"Found {num_cams} cameras but only {len(img_names)} image names")

    pinhole_dict = {}
    for idx in range(num_cams):
        world_mat = camera_dict['world_mat_%d' % idx].astype(np.float32)
        scale_mat = camera_dict['scale_mat_%d' % idx].astype(np.float32)
        P = (world_mat @ scale_mat)[:3, :4]
        K, pose_c2w = load_K_Rt_from_P(None, P)

        W2C = np.linalg.inv(pose_c2w)

        # Cast everything to built-in numbers so json.dump never meets a NumPy
        # scalar type it cannot serialise.
        qvec = [float(q) for q in rotmat2qvec(W2C[:3, :3])]
        tvec = [float(x) for x in W2C[:3, 3]]

        # params layout: w, h, fx, fy, cx, cy, qvec (4 values), tvec (3 values)
        pinhole_dict[img_names[idx]] = [
            int(w), int(h),
            float(K[0, 0]), float(K[1, 1]), float(K[0, 2]), float(K[1, 2]),
            qvec[0], qvec[1], qvec[2], qvec[3],
            tvec[0], tvec[1], tvec[2],
        ]

    with open(pinhole_dict_file, 'w') as fp:
        json.dump(pinhole_dict, fp, indent=2, sort_keys=True)


def _run_colmap(command, *options):
    """Run one COLMAP sub-command and raise if it exits with an error."""
    import subprocess  # local import: not part of this file's import header

    # An argument list (no shell) keeps paths with spaces intact, and
    # check=True stops the pipeline instead of continuing on a failed step.
    subprocess.run(['colmap', command, *options], check=True)


def _probe_image_size(image_dir, img_names):
    """Return (height, width) of the first file in ``img_names`` that cv2 can read."""
    for name in img_names:
        image = cv2.imread(os.path.join(image_dir, name))
        if image is not None:
            return image.shape[:2]
    raise ValueError(f"No readable image found in {image_dir}")


def init_colmap(args):
    """Build a COLMAP sparse model with known poses for every scene under ``args.dtu_path``.

    For each scene directory this runs feature extraction and sequential
    matching, converts ``cameras.npz`` to COLMAP text files, then triangulates
    points and runs bundle adjustment with extrinsics kept fixed.

    Args:
        args: namespace with a ``dtu_path`` attribute naming the dataset root;
            every sub-directory is treated as one scene containing an ``image``
            folder and a ``cameras.npz`` file.

    Raises:
        ValueError: if ``args.dtu_path`` is empty or no image can be read.
        Exception: if a scene directory has no ``image`` folder.
        subprocess.CalledProcessError: if a COLMAP command fails.
    """
    if not args.dtu_path:
        raise ValueError("Provide the dataset root via --dtu_path")

    for scene in sorted(os.listdir(args.dtu_path)):
        scene_path = os.path.join(args.dtu_path, scene)
        # Stray files in the dataset root are not scenes.
        if not os.path.isdir(scene_path):
            continue

        image_dir = os.path.join(scene_path, 'image')
        if not os.path.exists(image_dir):
            raise Exception(f"'image` folder cannot be found in {scene_path}."
                            "Please check the expected folder structure in DATA_PREPROCESSING.md")

        db_file = os.path.join(scene_path, 'database.db')
        sfm_dir = os.path.join(scene_path, 'sparse')

        # extract features
        _run_colmap('feature_extractor',
                    '--database_path', db_file,
                    '--image_path', image_dir,
                    '--ImageReader.camera_model=PINHOLE',
                    '--SiftExtraction.use_gpu=true',
                    '--SiftExtraction.num_threads=32',
                    '--ImageReader.single_camera=true')

        # match features
        _run_colmap('sequential_matcher',
                    '--database_path', db_file,
                    '--SiftMatching.use_gpu=true')

        # read poses
        camera_dict = np.load(os.path.join(scene_path, 'cameras.npz'))

        # os.listdir returns names in arbitrary order; sort them so that image i
        # is paired with camera i.
        img_names = sorted(os.listdir(image_dir))
        # Use the real image resolution rather than a fixed 1600x1200.
        height, width = _probe_image_size(image_dir, img_names)

        # convert to colmap files
        pinhole_dict_file = os.path.join(scene_path, 'pinhole_dict.json')
        convert_cam_dict_to_pinhole_dict(camera_dict, pinhole_dict_file,
                                         img_names=img_names, h=height, w=width)
        create_init_files(pinhole_dict_file, db_file, sfm_dir)

        # triangulate points with the known poses, then bundle adjustment
        _run_colmap('point_triangulator',
                    '--database_path', db_file,
                    '--image_path', image_dir,
                    '--input_path', sfm_dir,
                    '--output_path', sfm_dir,
                    '--Mapper.tri_ignore_two_view_tracks=true')
        _run_colmap('bundle_adjuster',
                    '--input_path', sfm_dir,
                    '--output_path', sfm_dir,
                    '--BundleAdjustment.refine_extrinsics=false')

def load_COLMAP_poses(cam_file, img_dir, tf='w2c'):
    """Read 4x4 pose matrices from a text file made of five-line records.

    Each record starts with a header line ``"<index> <valid> <x>"`` followed by
    four lines holding the rows of a 4x4 matrix. Records whose second header
    field is ``-1`` are skipped.

    Args:
        cam_file: path of the pose file.
        img_dir: image directory; its sorted listing maps indices to names.
        tf: ``'c2w'`` returns the matrices as read, keyed by integer index;
            any other value returns their inverses keyed by image file name.
    """
    # image names, in sorted order
    image_names = sorted(os.listdir(img_dir))

    with open(cam_file) as handle:
        rows = handle.readlines()

    # matrices exactly as stored in the file (treated as C2W)
    stored = {}
    for line_no, text in enumerate(rows):
        offset = line_no % 5
        if offset == 0:  # header
            img_idx, valid, _ = text.split(' ')
            if valid != '-1':
                stored[int(img_idx)] = np.eye(4)
            continue
        key = int(img_idx)
        if key in stored:
            values = np.array([float(tok) for tok in text.split(' ')])
            stored[key][offset - 1, :] = values

    if tf == 'c2w':
        return stored

    # convert to W2C (follow nerf convention)
    inverted = {}
    for index, matrix in stored.items():
        inverse = np.linalg.inv(matrix)
        inverted[image_names[index]] = inverse
    return inverted

def load_transformation(trans_file):
    """Load a matrix from a text file into a 4x4 array.

    Line ``i`` of the file, split on single spaces, overwrites row ``i`` of an
    identity matrix; rows without a matching line keep their identity values.
    """
    with open(trans_file) as handle:
        rows = handle.readlines()

    matrix = np.eye(4)
    for row_idx, text in enumerate(rows):
        matrix[row_idx, :] = np.array(list(map(float, text.split(' '))))
    return matrix

def align_gt_with_cam(pts, trans):
    """Apply the inverse of the 4x4 transform ``trans`` to the points ``pts``.

    ``pts`` is multiplied by the transposed 3x3 block of the inverse and the
    last column of the inverse (first three entries) is added as an offset.
    """
    inverse = np.linalg.inv(trans)
    rotation = inverse[:3, :3]
    offset = inverse[:3, -1]
    return pts @ rotation.transpose(-1, -2) + offset

def compute_bound(pts):
    """Return the centre, padded radius and axis-aligned bounds of ``pts``.

    Returns:
        ``(center, radius, bounds)`` where ``center`` is the midpoint of the
        per-axis min/max, ``radius`` is the largest distance from ``center`` to
        any point scaled by 1.01, and ``bounds`` is a list of ``[min, max]``
        pairs, one per axis.
    """
    lower = pts.min(axis=0)
    upper = pts.max(axis=0)
    bounding_box = np.array([lower, upper])
    center = bounding_box.mean(axis=0)
    distances = np.linalg.norm(pts - center, axis=-1)
    radius = np.max(distances) * 1.01
    return center, radius, bounding_box.T.tolist()

if __name__ == '__main__':
    # Command-line entry point: read the dataset root from --dtu_path and
    # run the COLMAP initialisation on it.
    cli = ArgumentParser()
    cli.add_argument('--dtu_path', type=str, default=None, help='Path to tanks and temples dataset')
    init_colmap(cli.parse_args())

Reconstruction of the BMVS Man object: image

hbb1 commented 4 weeks ago

Did you preprocess the cameras correctly? What is the rendering outcome?

aaryapatel007 commented 4 weeks ago

The training rendering outcome (train/renders) is good but the novel view synthesis (traj/renders) output is poor.

training rendering outcome (train/renders) 00014

novel view synthesis (traj/renders) output: 00017

hbb1 commented 4 weeks ago

Is your camera's principal point centered in the image?

aaryapatel007 commented 4 weeks ago

How do I check that?

FYI, The BMVS Man object input images are very zoomed in, often cropping the man object from top and bottom.

hbb1 commented 4 weeks ago

Check whether cx == width // 2 and cy == height // 2. See here: https://github.com/graphdeco-inria/gaussian-splatting/issues/144#issuecomment-1938504456

aaryapatel007 commented 4 weeks ago

OK, I was using your default getProjectionMatrix() function in graphics_util.py. Should I update it to getProjectionMatrixShift() as given here and train again to ensure principal points pass through the center of the image?

hbb1 commented 4 weeks ago

It depends on whether you use an ideal pinhole camera (principal point at the image center).

aaryapatel007 commented 4 weeks ago

When running feature extraction in COLMAP, I set camera_model=PINHOLE.

# COLMAP feature extraction: camera model forced to PINHOLE, one camera shared by all images, SIFT on GPU with 32 threads.
os.system(f"colmap feature_extractor --database_path {scene_path}/database.db \
                --image_path {scene_path}/image \
                --ImageReader.camera_model=PINHOLE \
                --SiftExtraction.use_gpu=true \
                --SiftExtraction.num_threads=32 \
                --ImageReader.single_camera=true"
                  )
hbb1 commented 4 weeks ago

The key is that you should ensure your camera's principal point lies at the image center. Please make sure that cx and cy are half of the image width and height, respectively. Otherwise, you should modify the projection matrix to include the shift.