Kai-46 / nerfplusplus

improves over nerf in 360 capture of unbounded scenes

Explanation of intersect_sphere and a faster implementation #11

Closed kwea123 closed 3 years ago

kwea123 commented 3 years ago

This function computes the intersection depth, but there is no explanation either in the paper or in the code. https://github.com/Kai-46/nerfplusplus/blob/57920483846bbab7708f9d30f797f42962b6d6e1/ddp_train_nerf.py#L42-L57

In case it's not clear to somebody, I'd like to provide some insight into how it is calculated, along with a faster implementation based on my approach. We have the ray origin o and the direction d, and we want the intersection depth with the unit sphere. A straightforward method is to find t such that ||o+td|| = 1. Squaring both sides gives a quadratic equation in t:

||d||^2*t^2 + 2*(o.d)*t + ||o||^2-1 = 0
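
For anyone who wants the intermediate step, the expansion behind this equation can be sketched as follows, using the same o, d and t as above. Squaring loses nothing here because both sides of ||o+td|| = 1 are non-negative.

```latex
\begin{aligned}
\|o + t d\|^2 &= (o + t d)\cdot(o + t d) \\
              &= \|d\|^2 t^2 + 2\,(o\cdot d)\,t + \|o\|^2 , \\[4pt]
\|o + t d\|^2 = 1 \;&\Longleftrightarrow\; \|d\|^2 t^2 + 2\,(o\cdot d)\,t + \|o\|^2 - 1 = 0 .
\end{aligned}
```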

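Then we can solve for t using the quadratic formula, keeping the positive root (the origin lies inside the sphere, so ||o||^2-1 < 0 and the two roots have opposite signs):

t = (-(o.d) + sqrt((o.d)^2 - ||d||^2*(||o||^2-1))) / ||d||^2

It results in the following implementation: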

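import torch

def intersect_sphere(rays_o, rays_d):
    # rays_o, rays_d: (N, 3) ray origins and (not necessarily unit) directions
    odotd = torch.sum(rays_o*rays_d, 1)
    d_norm_sq = torch.sum(rays_d**2, 1)
    o_norm_sq = torch.sum(rays_o**2, 1)
    # a quarter of the quadratic's discriminant: (o.d)^2 - ||d||^2*(||o||^2-1)
    determinant = odotd**2+(1-o_norm_sq)*d_norm_sq
    assert torch.all(determinant>=0), \
        'Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!'
    # positive root: depth along each ray to the unit sphere, shape (N,)
    return (torch.sqrt(determinant)-odotd)/d_norm_sq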

which I have verified to yield the same result (epsilon-close) as the original implementation while running 5-10x faster (11 ms for the original vs 2 ms for this version, for 100k rays on my PC; not that significant, though).
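
A quick way to sanity-check the closed form is to plug the returned depth back into o + t*d and confirm the point lands on the unit sphere. The sketch below does that for a single ray in plain Python (no torch, so it runs anywhere). The name `intersect_unit_sphere` and the sample ray are made up for illustration and are not part of the repository.

```python
import math

def intersect_unit_sphere(o, d):
    """Positive depth t with ||o + t*d|| == 1, assuming the origin o lies inside the unit sphere."""
    odotd = sum(oi * di for oi, di in zip(o, d))
    d_norm_sq = sum(di * di for di in d)
    o_norm_sq = sum(oi * oi for oi in o)
    # Quarter of the quadratic's discriminant; negative means the ray's line misses the sphere.
    quarter_disc = odotd ** 2 + (1.0 - o_norm_sq) * d_norm_sq
    if quarter_disc < 0:
        raise ValueError("ray does not intersect the unit sphere")
    return (math.sqrt(quarter_disc) - odotd) / d_norm_sq

# Hypothetical sample ray: origin inside the sphere, direction deliberately not unit length.
o = (0.3, -0.2, 0.1)
d = (1.0, 2.0, -0.5)

t = intersect_unit_sphere(o, d)
point = tuple(oi + t * di for oi, di in zip(o, d))
radius = math.sqrt(sum(p * p for p in point))
print(t, radius)  # radius should be 1 up to floating-point error
```

The same arithmetic is what the batched torch version computes per row. Only the sums over the coordinate axis change.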

Another possible optimization is to normalize rays_d from the beginning; that way we can get rid of d_norm_sq in intersect_sphere and also here https://github.com/Kai-46/nerfplusplus/blob/57920483846bbab7708f9d30f797f42962b6d6e1/ddp_model.py#L82-L83
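
To illustrate what normalizing buys, here is a minimal single-ray sketch in plain Python comparing the general formula with the unit-direction one. The function names `depth_general` and `depth_unit_dir` and the sample ray are assumptions made for this example and do not come from the repository.

```python
import math

def depth_general(o, d):
    """Depth to the unit sphere for any non-zero direction d (depth is in units of ||d||)."""
    odotd = sum(a * b for a, b in zip(o, d))
    d_norm_sq = sum(b * b for b in d)
    o_norm_sq = sum(a * a for a in o)
    return (math.sqrt(odotd ** 2 + (1.0 - o_norm_sq) * d_norm_sq) - odotd) / d_norm_sq

def depth_unit_dir(o, d_unit):
    """Same intersection, assuming ||d_unit|| == 1 so that d_norm_sq drops out."""
    odotd = sum(a * b for a, b in zip(o, d_unit))
    o_norm_sq = sum(a * a for a in o)
    return math.sqrt(odotd ** 2 + 1.0 - o_norm_sq) - odotd

# Hypothetical sample ray with a non-unit direction.
o = (0.3, -0.2, 0.1)
d = (1.0, 2.0, -0.5)
d_len = math.sqrt(sum(b * b for b in d))
d_unit = tuple(b / d_len for b in d)

t = depth_general(o, d)
t_unit = depth_unit_dir(o, d_unit)

# Both depths describe the same 3D point on the sphere.
p_general = tuple(a + t * b for a, b in zip(o, d))
p_unit = tuple(a + t_unit * b for a, b in zip(o, d_unit))
```

One thing to keep in mind: after normalization the depth is a Euclidean distance, so t_unit equals t * ||d||. Any code that consumes these depths together with the unnormalized directions would need to be adjusted consistently.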

kwea123 commented 3 years ago

Just noticed that the same concept can be used to compute depth2pts_outside, just by replacing the norm 1 with norm r, i.e. find t such that ||o+td|| = r. This implementation is consistently faster than the original one, and probably easier to understand.
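
Replacing 1 with r in the quadratic gives ||d||^2*t^2 + 2*(o.d)*t + ||o||^2-r^2 = 0, with the same positive-root solution. The single-ray sketch below covers only that intersection step, not the rest of what depth2pts_outside does in the repository. The name `intersect_sphere_radius` and the sample values are hypothetical.

```python
import math

def intersect_sphere_radius(o, d, r):
    """Positive depth t with ||o + t*d|| == r, assuming the origin lies inside the sphere (||o|| < r)."""
    odotd = sum(a * b for a, b in zip(o, d))
    d_norm_sq = sum(b * b for b in d)
    o_norm_sq = sum(a * a for a in o)
    # Same expression as for the unit sphere, with 1 replaced by r^2.
    quarter_disc = odotd ** 2 + (r * r - o_norm_sq) * d_norm_sq
    return (math.sqrt(quarter_disc) - odotd) / d_norm_sq

# Hypothetical sample ray and a few radii r >= 1.
o = (0.3, -0.2, 0.1)
d = (1.0, 2.0, -0.5)

results = []
for r in (1.0, 2.0, 4.0):
    t = intersect_sphere_radius(o, d, r)
    point = tuple(a + t * b for a, b in zip(o, d))
    results.append((r, t, math.sqrt(sum(c * c for c in point))))

for r, t, norm in results:
    print(f"r={r}: t={t:.6f}, ||o+td||={norm:.6f}")
```

Since the origin is inside the unit sphere and r >= 1, the quantity under the square root stays positive, so no extra check is needed in this setting.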