Closed insomniaaac closed 3 weeks ago
I did not implement what you mentioned. You can do it based on this one: https://github.com/yzslab/gaussian-splatting-lightning/blob/f9fb16b993add3d5a8dbb396af59590ff00c23e5/utils/prune_partitions_v2.py#L24-L64
Thanks for the quick reply! Here is my simple implementation:
import torch
import json
import argparse
import glob
from tqdm import tqdm
import torchvision
from internal.cameras.cameras import Cameras
from internal.renderers.vanilla_renderer import VanillaRenderer
from internal.viewer.renderer import ViewerRenderer
from internal.utils.gaussian_model_loader import GaussianModelLoader
from internal.utils.gaussian_model_editor import MultipleGaussianModelEditor
import os
def parse_cameras_json(path: str):
    """Load camera poses and intrinsics from a cameras.json file.

    The file is read as a JSON list. Every entry must provide "rotation",
    "position", "width", "height", "fx", "fy" and "img_name"; "cx" and "cy"
    are optional and default to half the image width/height.

    "rotation" and "position" are placed into a 4x4 camera-to-world matrix,
    which is then inverted, so the ``R``/``T`` handed to ``Cameras`` are
    world-to-camera.

    Args:
        path: Path of the JSON file to read.

    Returns:
        A ``(cameras, img_names)`` tuple: the ``Cameras`` batch (one camera
        per JSON entry) and the entries' "img_name" values in file order.

    Raises:
        ValueError: If the file contains no camera entries.
    """
    with open(path, "r", encoding="utf-8") as f:
        entries = json.load(f)

    if not entries:
        # torch.stack() on an empty list fails with an unhelpful message.
        raise ValueError(f"no cameras found in {path}")

    c2w_list = []
    for entry in entries:
        c2w = torch.eye(4)
        c2w[:3, :3] = torch.tensor(entry["rotation"], dtype=torch.float)
        c2w[:3, 3] = torch.tensor(entry["position"], dtype=torch.float)
        c2w_list.append(c2w)

    # Batched inversion: camera-to-world -> world-to-camera.
    w2c = torch.linalg.inv(torch.stack(c2w_list))
    n_cameras = w2c.shape[0]

    # Explicit float dtype: without it, a JSON file that stores the
    # intrinsics as integers would produce int64 tensors.
    fx = torch.tensor([entry["fx"] for entry in entries], dtype=torch.float)
    fy = torch.tensor([entry["fy"] for entry in entries], dtype=torch.float)
    cx = torch.tensor(
        [entry.get("cx", entry["width"] / 2) for entry in entries],
        dtype=torch.float,
    )
    cy = torch.tensor(
        [entry.get("cy", entry["height"] / 2) for entry in entries],
        dtype=torch.float,
    )
    img_names = [entry["img_name"] for entry in entries]

    return Cameras(
        R=w2c[..., :3, :3],
        T=w2c[..., :3, 3],
        fx=fx,
        fy=fy,
        cx=cx,
        cy=cy,
        width=torch.tensor([entry["width"] for entry in entries], dtype=torch.int),
        height=torch.tensor([entry["height"] for entry in entries], dtype=torch.int),
        appearance_id=torch.zeros(n_cameras, dtype=torch.int),
        normalized_appearance_id=torch.zeros(n_cameras, dtype=torch.float),
        distortion_params=torch.zeros((n_cameras, 4), dtype=torch.float),
        camera_type=torch.zeros(n_cameras, dtype=torch.int),
    ), img_names
def initializer_viewer_renderer(
        model_paths: list[str],
        enable_transform: bool,
        sh_degree: int,
        background_color,
        renderer_override,
        device,
) -> ViewerRenderer:
    """Load one or more models and wrap them in a ``ViewerRenderer``.

    Args:
        model_paths: Paths passed one by one to
            ``GaussianModelLoader.search_and_load``.
        enable_transform: When several models are given and this is true,
            the models are loaded on the CPU; otherwise on CUDA.
        sh_degree: Accepted for interface compatibility; not used here.
        background_color: Background colour, either a sequence of numbers
            or a tensor. It is copied into a float tensor on ``device``.
        renderer_override: If not ``None``, used instead of the renderer
            that would otherwise be selected.
        device: Device given to ``MultipleGaussianModelEditor`` and used
            for the background tensor.

    Returns:
        The configured ``ViewerRenderer``.
    """
    if len(model_paths) == 1 or enable_transform is False:
        load_device = torch.device("cuda")
    else:
        load_device = torch.device("cpu")

    model_list = []
    renderer = None
    for model_path in model_paths:
        model, renderer = GaussianModelLoader.search_and_load(model_path, load_device)
        model.freeze()
        model_list.append(model)

    # With several models the per-model renderer is replaced by a
    # VanillaRenderer.
    if len(model_paths) > 1:
        renderer = VanillaRenderer()

    if renderer_override is not None:
        print(f"Renderer: {renderer_override.__class__}")
        renderer = renderer_override

    model_manager = MultipleGaussianModelEditor(model_list, device)

    # torch.tensor() on an existing tensor emits a copy-construct
    # UserWarning; copy tensors with detach().clone() instead.
    if isinstance(background_color, torch.Tensor):
        background = background_color.detach().clone().to(dtype=torch.float, device=device)
    else:
        background = torch.tensor(background_color, dtype=torch.float, device=device)

    return ViewerRenderer(model_manager, renderer, background)
def save_image(image, output_path):
    """Write ``image`` to ``output_path`` with ``torchvision.utils.save_image``."""
    torchvision.utils.save_image(tensor=image, fp=output_path)
def render_frames(
        cameras: Cameras,
        viewer_renderer: ViewerRenderer,
        frame_output_path: str,
        image_names: list,
        device,
):
    """Render every camera and save each frame under its image name.

    Args:
        cameras: Camera batch; indexed from 0 to ``len(cameras) - 1``.
        viewer_renderer: Renderer whose ``get_outputs`` produces the frame.
        frame_output_path: Directory the frames are written into.
        image_names: File name for each frame, indexed like ``cameras``.
        device: Device each camera is moved to before rendering.
    """
    camera_count = len(cameras)
    progress = tqdm(range(camera_count), desc="rendering frames")
    for frame_index in progress:
        # Move the camera to the render device and pick its output name.
        current_camera = cameras[frame_index].to_device(device)
        target_name = image_names[frame_index]

        # Render, then bring the result back to the CPU for saving.
        rendered = viewer_renderer.get_outputs(current_camera)
        rendered = rendered.cpu()

        target_path = os.path.join(frame_output_path, target_name)
        save_image(rendered, target_path)
if __name__ == "__main__":
    # Command line: one or more model paths, the cameras JSON to render
    # from, the output directory, and an optional 2DGS renderer switch.
    parser = argparse.ArgumentParser()
    parser.add_argument("model_paths", type=str, nargs="+")
    parser.add_argument("--camera-json-filename", type=str, required=True)
    parser.add_argument("--output-path", type=str, required=True)
    parser.add_argument("--vanilla_gs2d", action="store_true", default=False)
    args = parser.parse_args()

    device = torch.device("cuda")

    # whether a 2DGS model
    renderer_override = None
    if args.vanilla_gs2d:
        from internal.renderers.vanilla_2dgs_renderer import Vanilla2DGSRenderer

        renderer_override = Vanilla2DGSRenderer()

    # instantiate renderer
    renderer = initializer_viewer_renderer(
        args.model_paths,
        enable_transform=False,
        sh_degree=3,
        background_color=torch.tensor([0.0, 0.0, 0.0], device=device),
        renderer_override=renderer_override,
        device=device,
    )

    cameras, img_names = parse_cameras_json(args.camera_json_filename)

    # create output path
    # Create the output directory itself. Using os.path.dirname() here
    # would only create its parent, and fails outright when the path has
    # no directory component.
    os.makedirs(args.output_path, exist_ok=True)
    # Remove .jpg frames left over from a previous run.
    for stale_frame in glob.glob(os.path.join(args.output_path, "*.jpg")):
        os.unlink(stale_frame)

    # start rendering
    render_frames(
        cameras,
        viewer_renderer=renderer,
        device=device,
        frame_output_path=args.output_path,
        image_names=img_names,
    )
Thanks for maintaining such a good framework!
After training, I have some checkpoints. As the next step, I want to render images from a given cameras.json file.
How can I do this?