brentyi / jaxlie

Rigid transforms + Lie groups in JAX
MIT License

Gradient of SO3.log() gives NaN when w=0 #9

Closed: Ending2015a closed this issue 1 year ago

Ending2015a commented 1 year ago

Hello, thank you very much for this amazing library.

I found a NaN issue at line 381 when computing the gradient of SO3.log(): https://github.com/brentyi/jaxlie/blob/ad93513b96e8e852af7862ee3c1e96a0d1dfd552/jaxlie/_so3.py#L379-L387

The following small example reproduces the NaN:

import jax
import jax.numpy as jnp
from jaxlie import SO3

a = jnp.array([jnp.pi, 0, 0], dtype=jnp.float32)
def func(x):
  return SO3.exp(x).log().sum()
print(jax.grad(func)(a))  # ===> [nan nan nan]

I think the reason is that jnp.where does not block unsafe gradients (e.g. x/0), as described in the official JAX FAQ: both branches are evaluated and differentiated, even the one that is not selected. This is what happens at line 381. When the rotation angle approaches pi or -pi, w is 0, so the unselected Taylor branch divides by zero and produces the bad gradient. To fix this, I suggest introducing safe_w, which equals w when use_taylor is True and 1.0 otherwise, and using it in the Taylor branch when computing atan_factor:

safe_w = jnp.where(use_taylor, w, 1.0)
atan_factor = jnp.where(
    use_taylor,
    2.0 / safe_w - 2.0 / 3.0 * norm_sq / safe_w**3,
    jnp.where(
        jnp.abs(w) < get_epsilon(w.dtype),
        jnp.where(w > 0, 1.0, -1.0) * jnp.pi / norm_safe,
        2.0 * atan_n_over_w / norm_safe,
    ),
)
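
The jnp.where behavior described above can be reproduced without jaxlie. The following is a minimal sketch, not code from the library: unsafe_recip and safe_recip are hypothetical names chosen for illustration, and a plain reciprocal stands in for the Taylor branch. It shows why masking only the output yields NaN, and why also masking the input to the unsafe operation (the same idea as safe_w) avoids it.

```python
import jax
import jax.numpy as jnp


def unsafe_recip(x):
    # Only the output is masked. 1.0 / x is still evaluated and
    # differentiated at x == 0, where its derivative is infinite.
    # The zero cotangent from jnp.where times that infinity gives NaN.
    return jnp.where(x != 0.0, 1.0 / x, 0.0)


def safe_recip(x):
    # Mask the input as well: the unselected branch now sees 1.0
    # instead of 0.0, so its derivative stays finite.
    safe_x = jnp.where(x != 0.0, x, 1.0)
    return jnp.where(x != 0.0, 1.0 / safe_x, 0.0)


print(jax.grad(unsafe_recip)(0.0))  # nan
print(jax.grad(safe_recip)(0.0))    # 0.0
```

Away from zero the two functions agree, so the extra jnp.where changes only the gradient at the problematic input.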

brentyi commented 1 year ago

Thanks!

Yeah, that looks like a clear oversight on my end. If you make a PR I'd be happy to merge it; otherwise I can make the fix and add it to the tests later this week.

Ending2015a commented 1 year ago

@brentyi I have opened a PR for it. Thank you.