WangFeng18 / 3d-gaussian-splatting

Implementation for 3d gaussian splatting
MIT License
324 stars, 19 forks

Faster scaling factor calculation and requirements.txt #3

Closed. vaibhavnayel closed this 1 year ago

vaibhavnayel commented 1 year ago

This PR adds a faster scaling factor calculation and a requirements.txt.

Problem

The original code uses a for loop to find each point's nearest neighbours in the sparse cloud. Although each distance calculation runs on the GPU, every iteration compares one point against the whole cloud, so the total work grows quadratically and the per-iteration overhead becomes significant as the point count increases.

Fix

The fix builds a KD tree over the sparse cloud and queries the nearest neighbours of all points in a single batched call instead of one point at a time. This allows larger scenes to be loaded in a reasonable amount of time.
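As a minimal sketch of the batched query described above: the helper name `mean_knn_distance` and the toy data are hypothetical, and SciPy's `cKDTree` is used here as a stand-in for the pykdtree `KDTree` that this PR uses, so this is not the code from the PR.

```python
import numpy as np
from scipy.spatial import cKDTree  # stand-in for pykdtree's KDTree (assumption)


def mean_knn_distance(points: np.ndarray, k: int = 3) -> np.ndarray:
    """Mean distance from each point to its k nearest neighbours (self excluded)."""
    tree = cKDTree(points)
    # Query k + 1 neighbours in one batched call: the closest hit for each
    # point is at distance 0 (the point itself), so it is dropped below.
    dist, _ = tree.query(points, k=k + 1)
    return dist[:, 1:].mean(axis=1)


# Toy data: the four corners of a unit square in the z = 0 plane.
square = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]], dtype=np.float64)
scales = mean_knn_distance(square)
```

Each corner has two neighbours at distance 1 and one at distance sqrt(2), so every entry of `scales` is (2 + sqrt(2)) / 3.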

I profiled the two implementations, and the timing plot shows the KD tree is a lot faster. Image loading now becomes the rate-limiting step in data preprocessing (it could also be performed in parallel, but that is not implemented by this PR).

Here is the profiling code:

from pykdtree.kdtree import KDTree
import torch
from tqdm import trange
import time
import matplotlib.pyplot as plt

timing_kdtree, timing_loop = [], []
num_pts = [1000, 10_000, 50_000, 100_000, 500_000]
for n in num_pts:
    _pos = torch.rand(n, 3).to(torch.float32).to('cuda')

    # KD tree: one batched query on the CPU (timing includes the GPU-to-CPU copy).
    start = time.time()
    _pos_np = _pos.cpu().numpy()
    kd_tree = KDTree(_pos_np)
    # k=4 because the nearest hit for each point is the point itself.
    dist, idx = kd_tree.query(_pos_np, k=4)
    mean_min_three_dis_kd = dist[:, 1:].mean(axis=1)
    timing_kdtree.append(time.time() - start)

    # Original approach: one GPU distance computation and sort per point.
    start = time.time()
    mean_min_three_dis = []
    for i_pos in trange(_pos.shape[0]):
        _r = (_pos[i_pos:i_pos+1] - _pos).norm(dim=-1).sort(dim=-1)[0][1:4].mean().item()
        mean_min_three_dis.append(_r)
    mean_min_three_dis_loop = torch.tensor(mean_min_three_dis, dtype=torch.float32)
    timing_loop.append(time.time() - start)

plt.plot(num_pts, timing_kdtree, label='kdtree', marker='o')
plt.plot(num_pts, timing_loop, label='loop', marker='x')
plt.xscale('log')
plt.yscale('log')
plt.xlabel('Number of points')
plt.ylabel('Time (s)')
plt.legend()
plt.show()
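The profiling script times both paths but does not compare their outputs. A small check along these lines could confirm they agree. This is a sketch only: it runs on the CPU with NumPy, uses SciPy's `cKDTree` as a stand-in for pykdtree, and the variable names are illustrative.

```python
import numpy as np
from scipy.spatial import cKDTree  # stand-in for pykdtree's KDTree (assumption)

rng = np.random.default_rng(0)
pts = rng.random((200, 3))

# Reference: per-point loop, mirroring the original approach.
# Index 0 of the sorted distances is the point itself, so take [1:4].
loop_result = np.array([
    np.sort(np.linalg.norm(pts - p, axis=1))[1:4].mean() for p in pts
])

# Batched KD tree query, mirroring the approach in this PR.
dist, _ = cKDTree(pts).query(pts, k=4)
kd_result = dist[:, 1:].mean(axis=1)

max_abs_diff = float(np.abs(loop_result - kd_result).max())
print(f"max abs difference: {max_abs_diff:.3e}")
```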

Hope this helps with faster experimentation!

WangFeng18 commented 1 year ago

Thanks for the KD tree insights; they really improve the efficiency of calculating distances between adjacent points.