WU-CVGL / BAD-NeRF

[CVPR 2023] 😈BAD-NeRF: Bundle Adjusted Deblur Neural Radiance Fields
https://wangpeng000.github.io/BAD-NeRF/
MIT License
185 stars, 13 forks

Fix NaN value in Spline.log_q2r_parallel(q) #5

Closed: pianwan closed this 1 year ago

pianwan commented 1 year ago

I made two changes to log_q2r_parallel(q).

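  1. Add 1e-20 inside the absolute value in log_q2r_taylor_w, i.e. criterion = w / torch.abs(w + 1e-20), to avoid the NaN that 0 / 0 produces when w = 0:

    def log_q2r_taylor_w(w, theta):
        # sign(w); the 1e-20 offset keeps the denominator non-zero at w == 0
        criterion = w / torch.abs(w + 1e-20)
        return criterion * torch.pi / theta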
  2. Use torch.where() to select a branch per element: log_q2r_taylor_theta where theta < eps_theta, log_q2r_taylor_w where theta >= eps_theta and |w| < eps_w, and log_q2r everywhere else.

    def log_q2r_parallel(q, eps_theta=1e-20, eps_w=1e-10):
        # q is a (..., 4) tensor of quaternions in (x, y, z, w) order
        x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]

        theta = torch.sqrt(x ** 2 + y ** 2 + z ** 2)

        # masks selecting the two Taylor-expansion branches
        bool_criterion_theta = (theta < eps_theta)
        bool_criterion_w = ((theta >= eps_theta) & (torch.abs(w) < eps_w))

        # evaluate every branch, then pick one per element with torch.where
        taylor_theta = log_q2r_taylor_theta(w, theta)
        taylor_w = log_q2r_taylor_w(w, theta)
        normal = log_q2r(w, theta)
        lambda_ = torch.where(bool_criterion_theta, taylor_theta,
                              torch.where(bool_criterion_w, taylor_w, normal))

        # rotation vector r = lambda_ * (x, y, z)
        r_ = torch.stack([lambda_ * x, lambda_ * y, lambda_ * z], -1)

        return r_
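
To see what change 1 does on its own, the sketch below evaluates the sign criterion at w = 0 and at small non-zero w. criterion_original is an assumption about the pre-fix form (w / torch.abs(w)), and criterion_where is a hypothetical alternative that is not part of this issue. At exactly w = 0 the patched criterion evaluates to 0 rather than ±1, so log_q2r_taylor_w returns 0 there instead of ±π/θ; either sign of ±π/θ describes the same 180° rotation.

```python
import torch


def criterion_original(w):
    # Assumed pre-fix form: w / |w| is 0 / 0 = NaN when w == 0.
    return w / torch.abs(w)


def criterion_patched(w):
    # Change 1 from the issue: the 1e-20 offset keeps the denominator non-zero at w == 0.
    return w / torch.abs(w + 1e-20)


def criterion_where(w):
    # Hypothetical alternative (not part of the issue): always exactly +1 or -1.
    return torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w))


w = torch.tensor([0.0, 1e-12, -1e-12])
print(criterion_original(w))  # NaN in the first slot
print(criterion_patched(w))   # 0 in the first slot, so lambda_ would be 0 there
print(criterion_where(w))     # +1 in the first slot
```

For small non-zero w all three agree on ±1; they differ only at w = 0.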

With these two changes, there seem to be no more NaN values.
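
A self-contained sketch for checking the patched function end to end follows. The issue does not show log_q2r or log_q2r_taylor_theta, so the versions here are hypothetical stand-ins modeled on a Sophus-style SO(3) log map, λ = 2·atan(θ/w)/θ, with the second-order expansion 2/w − 2θ²/(3w³) for small θ; the actual BAD-NeRF helpers may differ.

```python
import math

import torch


def log_q2r(w, theta):
    # Hypothetical stand-in: lambda = 2 * atan(theta / w) / theta (0 / 0 at theta == 0).
    return 2.0 * torch.atan(theta / w) / theta


def log_q2r_taylor_theta(w, theta):
    # Hypothetical stand-in: second-order Taylor expansion of log_q2r around theta = 0.
    return 2.0 / w - 2.0 / 3.0 * theta ** 2 / w ** 3


def log_q2r_taylor_w(w, theta):
    # Change 1 from the issue: sign(w) * pi / theta, the limit of log_q2r as w -> 0.
    criterion = w / torch.abs(w + 1e-20)
    return criterion * torch.pi / theta


def log_q2r_parallel(q, eps_theta=1e-20, eps_w=1e-10):
    # Change 2 from the issue: pick a branch per element with torch.where.
    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    theta = torch.sqrt(x ** 2 + y ** 2 + z ** 2)
    bool_criterion_theta = theta < eps_theta
    bool_criterion_w = (theta >= eps_theta) & (torch.abs(w) < eps_w)
    lambda_ = torch.where(bool_criterion_theta, log_q2r_taylor_theta(w, theta),
                          torch.where(bool_criterion_w, log_q2r_taylor_w(w, theta),
                                      log_q2r(w, theta)))
    return torch.stack([lambda_ * x, lambda_ * y, lambda_ * z], -1)


# Identity, a 1 rad rotation about z, and a 180 degree rotation about x, as (x, y, z, w).
q = torch.tensor([
    [0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, math.sin(0.5), math.cos(0.5)],
    [1.0, 0.0, 0.0, 0.0],
])
r = log_q2r_parallel(q)
print(r)  # every entry is finite; row 1 is close to (0, 0, 1)

# Gradients can still be NaN: torch.where back-propagates through every branch,
# and log_q2r hits 0 / 0 at the identity quaternion.
q_id = torch.tensor([0.0, 0.0, 0.0, 1.0], requires_grad=True)
log_q2r_parallel(q_id).sum().backward()
print(q_id.grad)  # contains NaN
```

The forward output stays finite because torch.where only copies the selected values. The backward pass, however, still differentiates the unselected log_q2r branch, so under these stand-ins NaN can reach q.grad even when the forward output is clean.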