TQTQliu / MVSGaussian

[ECCV 2024] MVSGaussian: Fast Generalizable Gaussian Splatting Reconstruction from Multi-View Stereo
https://mvsgaussian.github.io/
MIT License

Weird results on a custom Blender-synthetic dataset #33

Closed · Runyu-Zhou05 closed this issue 3 months ago

Runyu-Zhou05 commented 3 months ago

Hi there! Thanks for your great work! I've encountered some unexpected behavior when evaluating the model pretrained on the DTU dataset (the default one). My dataset consists of ~200 views of a Blender-synthesized LEGO object (a few example views are attached below). The original images are 800x800, and the poses are specified in a transforms_train.json file, just like the NeRF-synthetic dataset. I combined the colmap and nerf dataset modules to build a module for this dataset. Here is my config file:

parent_cfg: configs/mvsgs/dtu_pretrain.yaml

train_dataset_module: lib.datasets.mydataset.mvsgs
test_dataset_module: lib.datasets.mydataset.mvsgs

mvsgs:
    bg_color: [1, 1, 1]
    test_input_views: 3
    eval_center: True
    reweighting: True
    scale_factor: 12
    cas_config:
        render_if: [False, True]
        volume_planes: [16, 8]

train_dataset:
    data_root: 'examples'
    split: 'train'
    input_h_w: [640, 640]
    input_ratio: 1.

test_dataset:
    data_root: 'examples'
    split: 'test'
    input_h_w: [640, 640]
    input_ratio: 1.

I use 640 for the input size because 800 causes out-of-memory errors. I failed to run COLMAP on this dataset due to a lack of features, so I used the same near-far settings as the nerf dataset:

        H, W = tar_img.shape[:2]
        near_far = np.array([2.5 * self.scale_factor, 5.5 * self.scale_factor]).astype(np.float32)
        ret.update({'near_far': np.array(near_far).astype(np.float32)})

The near-far settings seem to play a crucial role: when I set near=0.1 and far=10, the PSNR is only ~8.
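As an alternative, I'm considering estimating near/far directly from the camera poses in transforms_train.json instead of hard-coding the NeRF values. Here is a rough sketch of what I have in mind (it assumes the object sits roughly at the world origin, as in Blender-synthetic scenes, and the radius margin is just a guess; the multiplication by scale_factor mirrors the translation scaling in read_cam):

import json
import numpy as np

def estimate_near_far(transforms_path, scale_factor=12, radius=1.5):
    # Rough near/far estimate for a Blender-synthetic scene: assumes the object
    # is centered at the world origin and fits inside a sphere of `radius`.
    info = json.load(open(transforms_path))
    cam_centers = np.array([np.array(f['transform_matrix'])[:3, 3] for f in info['frames']])
    dists = np.linalg.norm(cam_centers, axis=-1)
    near = max(dists.min() - radius, 0.1)
    far = dists.max() + radius
    # read_cam multiplies the camera translations by scale_factor, so the
    # depth range has to be scaled the same way to stay consistent.
    return np.array([near, far], dtype=np.float32) * scale_factor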

Also I noticed that when the input size is set to 400, I got this error:

  File ".../MVSGaussian/lib/networks/mvsgs/cost_reg_net.py", line 80, in forward
    x = conv2 + self.conv9(x)
RuntimeError: The size of tensor a (25) must match the size of tensor b (26) at non-singleton dimension 4

Anyway, with the input size set to 640, I get a PSNR of about 14 and these test results:

[rendered test views]

Have you ever encountered similar problems? Should I try other near-far settings or a different condition-view selection approach?

Additionally, in case of need, here is my dataset module:

import numpy as np
import os
from lib.config import cfg
import imageio
import cv2
import random
from lib.utils import data_utils
import torch
import json
from lib.datasets import mvsgs_utils
from lib.utils.video_utils import *

class Dataset:
    def __init__(self, **kwargs):
        super(Dataset, self).__init__()
        # cfg.workspace is '.' and kwargs['data_root'] is the dataset root
        self.data_root: str = os.path.join(cfg.workspace, kwargs['data_root'])
        self.split = kwargs['split']
        self.input_h_w = kwargs['input_h_w']
        self.scale_factor = cfg.mvsgs.scale_factor # can be 12
        self.build_metas()
        self.zfar = 100.0
        self.znear = 0.01
        self.trans = [0.0, 0.0, 0.0]
        self.scale = 1.0

    def build_metas(self):
        self.scene_infos = {}
        self.metas = []
        self.data_root = self.data_root.strip()
        if self.data_root.endswith('/'):
            self.data_root = self.data_root[:-1]
        scene = os.path.basename(self.data_root)
        json_info = json.load(open(os.path.join(self.data_root, 'transforms_train.json')))
        b2c = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])  # Blender/OpenGL camera -> OpenCV: flip the y and z axes
        scene_info = {'ixts': [], 'exts': [], 'img_paths': []}
        frames: list = json_info['frames']
        frames.sort(key=lambda x: int(os.path.basename(x['file_path'])))
        for idx, frame in enumerate(frames):
            c2w = np.array(frame['transform_matrix'])
            c2w = c2w @ b2c
            ext = np.linalg.inv(c2w)
            ixt = np.eye(3)
            ixt[0][2], ixt[1][2] = 400, 400
            focal = .5 * 800 / np.tan(.5 * json_info['camera_angle_x'])
            ixt[0][0], ixt[1][1] = focal, focal
            scene_info['ixts'].append(ixt.astype(np.float32))
            scene_info['exts'].append(ext.astype(np.float32))
            img_path = os.path.join(self.data_root, frame['file_path'] + '.png')
            scene_info['img_paths'].append(img_path)
        img_len = len(frames)
        #
        render_ids = [j for j in range(img_len//8, img_len, img_len//4)] # test views
        train_ids = [j for j in range(img_len) if j not in render_ids]
        #
        if self.split == 'train':
            render_ids = train_ids
        c2ws = np.stack([np.linalg.inv(scene_info['exts'][idx]) for idx in train_ids])
        scene_info['c2ws'] = c2ws.astype(np.float32)
        self.scene_infos[scene] = scene_info

        for i in render_ids: # target views; the nearest training views serve as condition views
            c2w = np.linalg.inv(scene_info['exts'][i]) # index by the global view id; c2ws only covers train_ids
            distance = np.linalg.norm((c2w[:3, 3][None] - c2ws[:, :3, 3]), axis=-1)
            argsorts = distance.argsort()
            argsorts = argsorts[1:] if i in train_ids else argsorts
            if self.split == 'train':
                src_views = [train_ids[j] for j in argsorts[:cfg.mvsgs.train_input_views[1]+1]]
            else:
                src_views = [train_ids[j] for j in argsorts[:cfg.mvsgs.test_input_views]]
            self.metas += [(scene, i, src_views)]

    def get_video_rendering_path(self, ref_poses, mode, near_far, train_c2w_all, n_frames=60, rads_scale=1.25):
        # loop over batch
        poses_paths = []
        ref_poses = ref_poses[None]
        for batch_idx, cur_src_poses in enumerate(ref_poses):
            if mode == 'interpolate':
                # convert to c2ws
                pose_square = torch.eye(4).unsqueeze(0).repeat(cur_src_poses.shape[0], 1, 1)
                cur_src_poses = torch.from_numpy(cur_src_poses)
                pose_square[:, :3, :] = cur_src_poses[:,:3]
                cur_c2ws = pose_square.double().inverse()[:, :3, :].to(torch.float32).cpu().detach().numpy()
                cur_path = get_interpolate_render_path(cur_c2ws, n_frames)
            elif mode == 'spiral':
                cur_c2ws_all = train_c2w_all
                cur_near_far = near_far.tolist()
                # rads_scale=...?
                cur_path = get_spiral_render_path(cur_c2ws_all, cur_near_far, rads_scale=rads_scale, N_views=n_frames)
            else:
                raise Exception(f'Unknown video rendering path mode {mode}')

            # convert back to extrinsics tensor
            cur_w2cs = torch.tensor(cur_path).inverse()[:, :3].to(torch.float32)
            poses_paths.append(cur_w2cs)

        poses_paths = torch.stack(poses_paths, dim=0)
        return poses_paths

    def __getitem__(self, index_meta):
        index, input_views_num = index_meta
        scene, tar_view, src_views = self.metas[index]
        if self.split == 'train':
            if np.random.random() < 0.1:
                src_views = src_views + [tar_view]
            src_views = random.sample(src_views, input_views_num)
        scene_info = self.scene_infos[scene]
        tar_img, tar_mask, tar_ext, tar_ixt = self.read_tar(scene_info, tar_view)
        src_inps, src_exts, src_ixts = self.read_src(scene_info, src_views)
        ret = {'src_inps': src_inps.transpose(0, 3, 1, 2),
               'src_exts': src_exts,
               'src_ixts': src_ixts}
        ret.update({'tar_ext': tar_ext,
                    'tar_ixt': tar_ixt})
        if self.split != 'train':
            ret.update({'tar_img': tar_img,
                        'tar_mask': tar_mask})

        H, W = tar_img.shape[:2]
        near_far = np.array([2.5 * self.scale_factor, 5.5 * self.scale_factor]).astype(np.float32)
        ret.update({'near_far': np.array(near_far).astype(np.float32)})
        ret.update({'meta': {'scene': scene, 'tar_view': tar_view, 'frame_id': 0}})

        for i in range(cfg.mvsgs.cas_config.num):
            rays, rgb, msk = mvsgs_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split)
            ret.update({f'rays_{i}': rays, f'rgb_{i}': rgb.astype(np.float32), f'msk_{i}': msk})
            s = cfg.mvsgs.cas_config.volume_scale[i]
            ret['meta'].update({f'h_{i}': int(H*s), f'w_{i}': int(W*s)})

        R = np.array(tar_ext[:3, :3], np.float32).reshape(3, 3).transpose(1, 0)
        T = np.array(tar_ext[:3, 3], np.float32)
        for i in range(cfg.mvsgs.cas_config.num):
            h, w = H*cfg.mvsgs.cas_config.render_scale[i], W*cfg.mvsgs.cas_config.render_scale[i]
            tar_ixt_ = tar_ixt.copy()
            tar_ixt_[:2,:] *= cfg.mvsgs.cas_config.render_scale[i]
            FovX = data_utils.focal2fov(tar_ixt_[0, 0], w)
            FovY = data_utils.focal2fov(tar_ixt_[1, 1], h)
            projection_matrix = data_utils.getProjectionMatrix(znear=self.znear, zfar=self.zfar, K=tar_ixt_, h=h, w=w).transpose(0, 1)
            world_view_transform = torch.tensor(data_utils.getWorld2View2(R, T, np.array(self.trans), self.scale)).transpose(0, 1)
            full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
            camera_center = world_view_transform.inverse()[3, :3]
            novel_view_data = {
                'FovX':  torch.FloatTensor([FovX]),
                'FovY':  torch.FloatTensor([FovY]),
                'width': w,
                'height': h,
                'world_view_transform': world_view_transform,
                'full_proj_transform': full_proj_transform,
                'camera_center': camera_center
            }
            ret[f'novel_view{i}'] = novel_view_data    

        if cfg.save_video:
            rendering_video_meta = []
            render_path_mode = 'spiral'
            train_c2w_all = np.linalg.inv(src_exts)
            poses_paths = self.get_video_rendering_path(src_exts, render_path_mode, near_far, train_c2w_all, n_frames=60)
            for pose in poses_paths[0]:
                R = np.array(pose[:3, :3], np.float32).reshape(3, 3).transpose(1, 0)
                T = np.array(pose[:3, 3], np.float32)
                FovX = data_utils.focal2fov(tar_ixt[0, 0], W)
                FovY = data_utils.focal2fov(tar_ixt[1, 1], H)
                projection_matrix = data_utils.getProjectionMatrix(znear=self.znear, zfar=self.zfar, K=tar_ixt, h=H, w=W).transpose(0, 1)
                world_view_transform = torch.tensor(data_utils.getWorld2View2(R, T, np.array(self.trans), self.scale)).transpose(0, 1)
                full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
                camera_center = world_view_transform.inverse()[3, :3]
                rendering_meta = {
                    'FovX':  torch.FloatTensor([FovX]),
                    'FovY':  torch.FloatTensor([FovY]),
                    'width': W,
                    'height': H,
                    'world_view_transform': world_view_transform,
                    'full_proj_transform': full_proj_transform,
                    'camera_center': camera_center,
                    'tar_ext': pose
                }
                for i in range(cfg.mvsgs.cas_config.num):
                    tar_ext[:3] = pose
                    rays, _, _ = mvsgs_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split)
                    rendering_meta.update({f'rays_{i}': rays})
                rendering_video_meta.append(rendering_meta)
            ret['rendering_video_meta'] = rendering_video_meta
        return ret

    def read_src(self, scene, src_views):
        src_ids = src_views
        ixts, exts, imgs = [], [], []
        for idx in src_ids:
            img, orig_size = self.read_image(scene, idx)
            imgs.append(((img/255.)*2-1).astype(np.float32))
            ixt, ext = self.read_cam(scene, idx, orig_size)
            ixts.append(ixt)
            exts.append(ext)
        return np.stack(imgs), np.stack(exts), np.stack(ixts)

    def read_tar(self, scene, view_idx):
        img, orig_size = self.read_image(scene, view_idx)
        img = (img/255.).astype(np.float32)
        ixt, ext = self.read_cam(scene, view_idx, orig_size)
        mask = np.ones_like(img[..., 0]).astype(np.uint8)
        return img, mask, ext, ixt

    def read_cam(self, scene, view_idx, orig_size):
        ext = scene['exts'][view_idx].astype(np.float32)
        ext[:3,3] *= self.scale_factor
        # copy so the cached intrinsics are not rescaled in place on every call
        ixt = scene['ixts'][view_idx].copy()
        # ixt[0, 2] = self.input_h_w[1] / 2
        # ixt[1, 2] = self.input_h_w[0] / 2
        ixt[0] *= self.input_h_w[1] / orig_size[0]  # orig_size is (W, H)
        ixt[1] *= self.input_h_w[0] / orig_size[1]
        return ixt, ext

    def read_image(self, scene, view_idx):
        img_path = scene['img_paths'][view_idx]
        img = (np.array(imageio.imread(img_path))).astype(np.float32)
        orig_size = img.shape[:2][::-1]
        img = cv2.resize(img, self.input_h_w[::-1], interpolation=cv2.INTER_AREA)
        return np.array(img), orig_size

    def __len__(self):
        return len(self.metas)

def get_K_from_params(params):
    K = np.zeros((3, 3)).astype(np.float32)
    K[0][0], K[0][2], K[1][2] = params[:3]
    K[1][1] = K[0][0]
    K[2][2] = 1.
    return K
TQTQliu commented 3 months ago

Hi, thanks for your interest. Your data have a large difference in viewpoint, which is a challenge for our model. This is also why COLMAP cannot be run on this dataset: there are too few shared features between views. However, here are some things that might be worth trying:

i) I noticed that you set bg_color: [1, 1, 1] because your data have a white background. However, for our nerf-synthetic dataset, which also has a white background, setting bg_color: [1, 1, 1] causes a performance drop (PSNR from 26.46 to 23.39), because we pre-train on the DTU dataset with bg_color: [0, 0, 0]. So I recommend you try bg_color: [0, 0, 0].

ii) If you use the same near-far settings as the nerf dataset, near_far = np.array([2.5 * self.scale_factor, 5.5 * self.scale_factor]).astype(np.float32), then the scale_factor in the config file should also be changed to 200, the same as in the nerf-synthetic config file.

iii) You can also adjust the number of sampling planes volume_planes in the config file; commonly used settings are [64, 8], [48, 8] and [16, 8].

BTW, the cause of the error when the input size is set to 400 is that H and W need to be integer multiples of 32.
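For reference, the requested resolution can be snapped down to a valid size with a small helper like the following (just an illustrative sketch, not code from the repo):

def round_down_to_multiple(x, base=32):
    # H and W must be integer multiples of 32 because the cost regularization
    # network downsamples the feature maps several times, so round down.
    return max(base, (x // base) * base)

input_h_w = [round_down_to_multiple(400), round_down_to_multiple(400)]  # -> [384, 384]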

Runyu-Zhou05 commented 3 months ago

Thanks a lot for your informative response. Through extensive experimentation, I've improved my PSNR a lot, but I still have some issues with the exported PLY file. (The large difference in viewpoint isn't actually the issue, since my dataset has ample views and only small differences between neighboring views.)

Improving PSNR

For improving PSNR, I've discovered:

Finally, I've got a PSNR of ~25 and these rendered images:

[rendered test views]

Exporting PLY

However, I get a weird PLY point cloud despite the reasonable PSNR. The exported PLY file is only ~300 KB, and here is the visualization:

[point cloud visualization]

My command was:

python run.py --type evaluate --cfg_file configs/mvsgs/mydataset_eval.yaml save_ply True dir_ply models/mvsgs test_dataset.data_root mydataset

Should I tweak something further?

TQTQliu commented 3 months ago

Glad to hear that you have made a big improvement in PSNR, and thank you for sharing.

Running the command python run.py --type evaluate --cfg_file configs/mvsgs/mydataset_eval.yaml save_ply True dir_ply models/mvsgs test_dataset.data_root mydataset produces an ordinary geometric point cloud, not a Gaussian point cloud. This point cloud is used as the initialization for the subsequent per-scene optimization, after which you obtain the optimized Gaussian point cloud.

Please note that we clarified this in the README (see here). We also updated two files: lib/networks/mvsgs/network.py and lib/scene/dataset_readers.py.
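If you want to double-check what a given .ply actually contains, one option is to list its vertex properties, for example with the plyfile package (an illustrative sketch; the path is a placeholder):

from plyfile import PlyData

ply = PlyData.read('point_cloud.ply')  # placeholder path to the exported file
vertex = ply['vertex']
print(vertex.count, 'points')
print([p.name for p in vertex.properties])
# A geometric point cloud only has x/y/z (plus optional color/normal) fields,
# whereas an optimized Gaussian point cloud in the standard 3DGS format also
# stores f_dc_*/f_rest_*, opacity, scale_* and rot_* properties.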

Runyu-Zhou05 commented 3 months ago

Thanks a million!