pianwan closed this 1 year ago.
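I made two changes to `log_q2r_parallel(q)`.

1. In `log_q2r_taylor_w`, I changed the criterion to `criterion = w / torch.abs(w + 1e-20)`. The added `+ 1e-20` avoids a NaN value when `w` is exactly 0.

```python
import torch

def log_q2r_taylor_w(w, theta):
    # Without the 1e-20 offset, w == 0 would give 0 / 0 = NaN here.
    criterion = w / torch.abs(w + 1e-20)
    return criterion * torch.pi / theta
```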
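A quick check of the first change, assuming the unpatched criterion was `w / torch.abs(w)`: the sketch below compares it with the patched version on a few small scalar parts. One side effect worth noting is that at exactly `w == 0` the patched criterion evaluates to 0 rather than to +1 or -1, so `log_q2r_taylor_w` returns 0 there. The `sign_like` line is a hypothetical alternative (not part of the original change) that keeps every entry at exactly +1 or -1, in case that is the intended behavior.

```python
import torch

# Sample scalar parts: exactly zero, tiny positive, tiny negative.
w = torch.tensor([0.0, 1e-12, -1e-12])

# Unpatched criterion (assumed): 0 / |0| is 0 / 0, i.e. NaN, at w == 0.
old = w / torch.abs(w)

# Patched criterion from this comment: the 1e-20 offset keeps the
# denominator nonzero at w == 0, so the result there is 0 instead of NaN.
patched = w / torch.abs(w + 1e-20)

# Hypothetical alternative: an explicit sign that treats w == 0 as positive.
sign_like = torch.where(w < 0, -torch.ones_like(w), torch.ones_like(w))
```

Away from `w == 0` the offset is far below float32 precision for these magnitudes, so the patched criterion still gives exactly +1 or -1 for the tiny nonzero entries.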
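2. I used `torch.where()` to replace the values where `theta < eps_theta`, and where `torch.abs(w) < eps_w` with `theta >= eps_theta`, by the corresponding Taylor-branch values.

```python
def log_q2r_parallel(q, eps_theta=1e-20, eps_w=1e-10):
    # q has shape (..., 4) with the scalar part last: (x, y, z, w).
    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    theta = torch.sqrt(x ** 2 + y ** 2 + z ** 2)  # norm of the vector part

    # Branch masks: tiny theta, or tiny |w| where theta is not tiny.
    bool_criterion_theta = (theta < eps_theta)
    bool_criterion_w = ((theta >= eps_theta) & (torch.abs(w) < eps_w))

    # All three candidates are computed for every element.
    # log_q2r_taylor_theta and log_q2r are not shown in this comment.
    taylor_theta = log_q2r_taylor_theta(w, theta)
    taylor_w = log_q2r_taylor_w(w, theta)
    normal = log_q2r(w, theta)

    # Take the Taylor-theta value, else the Taylor-w value, else the general formula.
    lambda_ = torch.where(bool_criterion_theta, taylor_theta,
                          torch.where(bool_criterion_w, taylor_w, normal))
    r_ = torch.stack([lambda_ * x, lambda_ * y, lambda_ * z], -1)
    return r_
```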
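The nested `torch.where` call picks one of three precomputed branches per element. As a minimal sketch of just that selection logic, the hypothetical `select_branch` helper below reuses the same masks but replaces the three branch functions with constant placeholders (0, 1, 2), so the result shows which branch each element would take. It does not reproduce the actual quaternion math.

```python
import torch

def select_branch(theta, w, eps_theta=1e-20, eps_w=1e-10):
    # Same masks as in log_q2r_parallel.
    bool_criterion_theta = theta < eps_theta
    bool_criterion_w = (theta >= eps_theta) & (torch.abs(w) < eps_w)
    # Hypothetical placeholder branches: each one just holds its own id.
    taylor_theta = torch.zeros_like(theta)  # branch 0
    taylor_w = torch.ones_like(theta)       # branch 1
    normal = torch.full_like(theta, 2.0)    # branch 2
    return torch.where(bool_criterion_theta, taylor_theta,
                       torch.where(bool_criterion_w, taylor_w, normal))

# Element 0: theta == 0; element 1: w == 0; element 2: neither.
theta = torch.tensor([0.0, 1.0, 1.0])
w = torch.tensor([1.0, 0.0, 0.5])
branches = select_branch(theta, w)  # tensor([0., 1., 2.])
```

Given the nesting, the `theta >= eps_theta` term in the second mask is redundant (the outer `torch.where` already claims every element with `theta < eps_theta`), but it keeps each mask self-describing.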
With these two changes, no NaN values seem to appear.
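One caveat worth checking, based on general PyTorch behavior rather than anything tested in this issue: `torch.where` only selects among tensors that have already been computed, so a NaN or inf in an unselected branch disappears from the forward result but can still become NaN in the backward pass (the zero gradient from the mask gets multiplied by an infinite local derivative). The toy sketch below uses a hypothetical `1/x` branch instead of the quaternion functions; the second half shows the common "double where" workaround of sanitizing the inputs before the unsafe operation.

```python
import torch

# Toy example (hypothetical, not the quaternion code): the unsafe branch is 1/x.
x = torch.tensor([0.0, 2.0], requires_grad=True)
unsafe = 1.0 / x  # inf at x == 0, discarded by torch.where below
y = torch.where(x == 0, torch.zeros_like(x), unsafe)
y.sum().backward()
# The forward values are finite, but x.grad[0] is NaN (0 * -inf).

# "Double where": replace problematic inputs before the unsafe op,
# so no branch ever divides by zero.
x2 = torch.tensor([0.0, 2.0], requires_grad=True)
x2_safe = torch.where(x2 == 0, torch.ones_like(x2), x2)
y2 = torch.where(x2 == 0, torch.zeros_like(x2), 1.0 / x2_safe)
y2.sum().backward()
# Same forward values as y, and x2.grad is finite everywhere.
```

Applied to `log_q2r_parallel`, the same idea would mean clamping or replacing `w` and `theta` before calling the branch functions, not only selecting their outputs afterwards.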