yzslab / gaussian-splatting-lightning

A 3D Gaussian Splatting framework with various derived algorithms and an interactive web viewer
Other
457 stars 38 forks source link

How to render images from specified viewports? #55

Closed insomniaaac closed 3 weeks ago

insomniaaac commented 3 weeks ago

Thanks for maintaining such a good framework!

After training, I got some checkpoints. As a next step, I want to render images from the cameras given in a cameras.json file:

[{"id": 0, "img_name": "frame_00000000.jpg", "width": 1080, "height": 1920, "position": [2.997182014760931, -0.4078853954204856, 0.14653383040844226], "rotation": [[0.049752599033574624, 0.6601904095224718, -0.7494486653960731], [-0.4031411701767174, 0.6998083783555331, 0.5896994408102707], [0.9137843705152496, 0.2727945321261955, 0.3009672198829944]], "fy": 1785.9247349332506, "fx": 1800.4946310676712, "cy": 960.0, "cx": 540.0}, {"id": 1, "img_name": "frame_00000001.jpg", "width": 1080, "height": 1920, "position": [2.949282970370609, -0.6335382228381512, 0.2673114072049478], "rotation": [[-0.0019237835554590309, 0.6668441561940055, -0.745194719792576], [-0.4011884017915847, 0.682081403631487, 0.6114023430508052], [0.915993540001645, 0.300139684430442, 0.2662179642797406]], "fy": 1785.9247349332506, "fx": 1800.4946310676712, "cy": 960.0, "cx": 540.0}, {"id": 2, "img_name": "frame_00000002.jpg", "width": 1080, "height": 1920, "position": [2.9667291513620366, -0.7387976432218415, 0.4629798781143499], "rotation": [[-0.0615571633528496, 0.6703490776602758, -0.7394882214883264], [-0.39696304628679396, 0.6633326200940405, 0.6343580810566571], [0.915768014031206, 0.33259878111501434, 0.22527093749116323]], "fy": 1785.9247349332506, "fx": 1800.4946310676712, "cy": 960.0, "cx": 540.0}, {"id": 3, "img_name": "frame_00000003.jpg", "width": 1080, "height": 1920, "position": [3.0453832762504836, -0.8551470378337114, 0.6554612477708741], "rotation": [[-0.125317384105251, 0.6663741320728737, -0.7350109314460163], [-0.38650934434140116, 0.649526706658274, 0.6547713983326062], [0.9137319519299936, 0.36614283202294223, 0.17616284108864333]], "fy": 1785.9247349332506, "fx": 1800.4946310676712, "cy": 960.0, "cx": 540.0}, {"id": 4, "img_name": "frame_00000004.jpg", "width": 1080, "height": 1920, "position": [3.0650147418053115, -0.9935019562786864, 0.8438821385157003], "rotation": [[-0.19630892974700345, 0.6585583488198234, -0.7264734718496564], [-0.3692671096993682, 
0.6366912476445024, 0.6769535115996974], [0.9083526879994689, 0.40115477857380044, 0.1181957606327081]], "fy": 1785.9247349332506, "fx": 1800.4946310676712, "cy": 960.0, "cx": 540.0}]

How can I do this?

yzslab commented 3 weeks ago

I have not implemented that. You can build it based on this one: https://github.com/yzslab/gaussian-splatting-lightning/blob/f9fb16b993add3d5a8dbb396af59590ff00c23e5/utils/prune_partitions_v2.py#L24-L64

insomniaaac commented 3 weeks ago

Thanks for the quick reply! Here is my simple implementation:

import torch
import json
import argparse
import glob
from tqdm import tqdm
import torchvision

from internal.cameras.cameras import Cameras
from internal.renderers.vanilla_renderer import VanillaRenderer
from internal.viewer.renderer import ViewerRenderer
from internal.utils.gaussian_model_loader import GaussianModelLoader
from internal.utils.gaussian_model_editor import MultipleGaussianModelEditor
import os

def parse_cameras_json(path: str):
    """Load a batch of cameras from a ``cameras.json`` file.

    The file must contain a JSON list of per-camera objects. Each object is
    read for ``rotation`` (3x3), ``position`` (3), ``width``, ``height``,
    ``fx``, ``fy``, ``img_name`` and, optionally, ``cx``/``cy``. When
    ``cx``/``cy`` are missing they default to the image centre.

    ``rotation``/``position`` are assembled into a 4x4 matrix that this
    function treats as camera-to-world. It is inverted to obtain the
    world-to-camera ``R``/``T`` handed to ``Cameras``.

    Args:
        path: Path to the JSON file.

    Returns:
        A ``(cameras, img_names)`` tuple. ``cameras`` is a batched
        ``Cameras`` with one entry per JSON object. ``img_names`` lists the
        ``img_name`` values in file order.

    Raises:
        ValueError: If the file contains no cameras.
    """
    # JSON text is UTF-8 by specification; don't rely on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)

    if len(records) == 0:
        raise ValueError(f"no cameras found in {path}")

    c2w_list = []
    img_names = []
    for record in records:
        c2w = torch.eye(4)
        c2w[:3, :3] = torch.tensor(record["rotation"], dtype=torch.float)
        c2w[:3, 3] = torch.tensor(record["position"], dtype=torch.float)
        c2w_list.append(c2w)
        img_names.append(record["img_name"])

    w2c = torch.linalg.inv(torch.stack(c2w_list))
    n_cameras = w2c.shape[0]

    # Force float intrinsics. Without an explicit dtype, a file that stores
    # integer values (e.g. "cx": 540) would produce int64 tensors.
    fx = torch.tensor([r["fx"] for r in records], dtype=torch.float)
    fy = torch.tensor([r["fy"] for r in records], dtype=torch.float)
    cx = torch.tensor([r.get("cx", r["width"] / 2) for r in records], dtype=torch.float)
    cy = torch.tensor([r.get("cy", r["height"] / 2) for r in records], dtype=torch.float)

    return Cameras(
        R=w2c[..., :3, :3],
        T=w2c[..., :3, 3],
        fx=fx,
        fy=fy,
        cx=cx,
        cy=cy,
        width=torch.tensor([r["width"] for r in records], dtype=torch.int),
        height=torch.tensor([r["height"] for r in records], dtype=torch.int),
        appearance_id=torch.zeros(n_cameras, dtype=torch.int),
        normalized_appearance_id=torch.zeros(n_cameras, dtype=torch.float),
        distortion_params=torch.zeros((n_cameras, 4), dtype=torch.float),
        camera_type=torch.zeros(n_cameras, dtype=torch.int),
    ), img_names

def initializer_viewer_renderer(
        model_paths: list[str],
        enable_transform: bool,
        sh_degree: int,
        background_color,
        renderer_override,
        device,
) -> ViewerRenderer:
    """Load one or more Gaussian models and wrap them in a ``ViewerRenderer``.

    Args:
        model_paths: Paths handed to ``GaussianModelLoader.search_and_load``.
        enable_transform: If false, or if only one model is given, models are
            loaded directly onto ``device``. Otherwise they are loaded on the
            CPU first.
        sh_degree: Currently unused; kept for signature compatibility.
        background_color: RGB background as a sequence or a tensor.
        renderer_override: If not None, replaces the renderer chosen below.
        device: Device for the model editor and the background tensor.

    Returns:
        A ``ViewerRenderer`` over a ``MultipleGaussianModelEditor`` of all
        loaded (frozen) models.

    Raises:
        ValueError: If ``model_paths`` is empty.
    """
    if not model_paths:
        raise ValueError("at least one model path is required")

    device = torch.device(device)
    # NOTE(review): the CPU path for multi-model + transform presumably lets
    # MultipleGaussianModelEditor move/transform the models itself — confirm.
    if len(model_paths) == 1 or not enable_transform:
        load_device = device
    else:
        load_device = torch.device("cpu")

    model_list = []
    renderer = None
    for model_path in model_paths:
        model, renderer = GaussianModelLoader.search_and_load(model_path, load_device)
        model.freeze()
        model_list.append(model)

    # With several models, the loader's per-model renderer is replaced by a
    # plain VanillaRenderer for the combined set.
    if len(model_paths) > 1:
        renderer = VanillaRenderer()
    if renderer_override is not None:
        print(f"Renderer: {renderer_override.__class__}")
        renderer = renderer_override

    model_manager = MultipleGaussianModelEditor(model_list, device)

    # as_tensor accepts both sequences and tensors; torch.tensor() on an
    # existing tensor emits a copy-construct UserWarning.
    background = torch.as_tensor(background_color, dtype=torch.float, device=device)
    return ViewerRenderer(model_manager, renderer, background)

def save_image(image, output_path):
    """Write the rendered ``image`` tensor to ``output_path``.

    Delegates to ``torchvision.utils.save_image``. No explicit format is
    passed, so torchvision infers it from the file extension.
    """
    torchvision.utils.save_image(tensor=image, fp=output_path)

def render_frames(
        cameras: Cameras,
        viewer_renderer: ViewerRenderer,
        frame_output_path: str,
        image_names: list,
        device,
):
    """Render each camera in ``cameras`` and save one image per camera.

    Camera ``idx`` is written to ``frame_output_path/image_names[idx]``.

    Args:
        cameras: Batched cameras; indexed one at a time and moved to ``device``.
        viewer_renderer: Renderer whose ``get_outputs(camera)`` yields the image.
        frame_output_path: Directory the images are written into.
        image_names: Output file name (may include sub-directories) per camera.
        device: Device each camera is moved to before rendering.

    Raises:
        ValueError: If ``image_names`` does not have one entry per camera.
    """
    n_cameras = len(cameras)
    if len(image_names) != n_cameras:
        raise ValueError(
            f"got {len(image_names)} image names for {n_cameras} cameras"
        )

    # Inference only: no autograd graph is needed, and skipping it saves memory.
    with torch.no_grad():
        for idx in tqdm(range(n_cameras), desc="rendering frames"):
            camera = cameras[idx].to_device(device)
            output_file = os.path.join(frame_output_path, image_names[idx])
            # img_name may contain sub-directories; make sure they exist.
            os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
            image = viewer_renderer.get_outputs(camera).cpu()
            save_image(image, output_file)

def main():
    """Command-line entry point: render every camera listed in a cameras.json file.

    Every ``*.jpg`` already present in ``--output-path`` is deleted before
    rendering starts.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("model_paths", type=str, nargs="+")
    parser.add_argument("--camera-json-filename", type=str, required=True)
    parser.add_argument("--output-path", type=str, required=True)
    # Accept the hyphenated spelling used by the other options, keeping the
    # original underscore spelling for backward compatibility.
    parser.add_argument(
        "--vanilla-gs2d", "--vanilla_gs2d",
        dest="vanilla_gs2d",
        action="store_true",
        default=False,
    )
    args = parser.parse_args()

    device = torch.device("cuda")

    # Use the 2DGS renderer when the checkpoint is a vanilla 2DGS model.
    renderer_override = None
    if args.vanilla_gs2d:
        from internal.renderers.vanilla_2dgs_renderer import Vanilla2DGSRenderer

        renderer_override = Vanilla2DGSRenderer()

    renderer = initializer_viewer_renderer(
        args.model_paths,
        enable_transform=False,
        sh_degree=3,
        # A plain sequence: the initializer converts it to a tensor on `device`.
        background_color=[0.0, 0.0, 0.0],
        renderer_override=renderer_override,
        device=device,
    )

    cameras, img_names = parse_cameras_json(args.camera_json_filename)

    # Create the output directory itself. The previous
    # os.path.dirname(output_path) only created its parent, and raised on a
    # bare directory name.
    os.makedirs(args.output_path, exist_ok=True)
    # Delete every existing *.jpg in the output directory. glob.escape keeps
    # special characters in the path from being treated as wildcards.
    for stale_file in glob.glob(os.path.join(glob.escape(args.output_path), "*.jpg")):
        os.unlink(stale_file)

    render_frames(
        cameras,
        viewer_renderer=renderer,
        device=device,
        frame_output_path=args.output_path,
        image_names=img_names,
    )


if __name__ == "__main__":
    main()