Hello, thank you very much for this amazing library.

I found a NaN issue at line 381 when calculating the gradient of `SO3.log()`: https://github.com/brentyi/jaxlie/blob/ad93513b96e8e852af7862ee3c1e96a0d1dfd552/jaxlie/_so3.py#L379-L387

The following small example produces NaN:

```python
import jax
import jax.numpy as jnp
from jaxlie import SO3

a = jnp.array([jnp.pi, 0, 0], dtype=jnp.float32)  # rotation by pi about the x-axis
def func(x):
    return SO3.exp(x).log().sum()
print(jax.grad(func)(a))  # ===> [nan nan nan]
```

I think the reason is that `jnp.where` does not block unsafe gradients (e.g. from `x/0`) in the branch it discards, as described in the official JAX FAQ. That is what happens in line 381: when the rotation angle approaches `pi` or `-pi`, `w` (which equals cos(angle / 2)) goes to 0, so the Taylor-branch expression in line 381 produces a bad gradient even though its value is not selected. To fix this, I suggest substituting `safe_w = 1.0` when `use_taylor` is False before calculating `atan_factor`.

Yeah, that looks like a clear oversight on my end. If you make a PR I'd be happy to merge it; otherwise I can make the fix and add it to the tests later this week.
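A minimal sketch of the `safe_w` guard suggested in the issue, written against a simplified stand-in rather than jaxlie's actual `SO3.log()`. The functions `atan_factor_unsafe` and `atan_factor_safe` are hypothetical: they take `w` and the norm of the quaternion's vector part, use a hypothetical fixed threshold `TAYLOR_EPS` instead of jaxlie's dtype-dependent epsilon, and keep only the Taylor/exact branch structure. The unsafe version evaluates `2.0 / w` in the Taylor branch even when that branch is discarded, so at `w = 0` its gradient is NaN; the safe version feeds each branch a harmless input wherever that branch is unused.

```python
import jax
import jax.numpy as jnp

# Hypothetical fixed threshold; jaxlie itself uses a dtype-dependent epsilon.
TAYLOR_EPS = 1e-6


def atan_factor_unsafe(w, norm):
    # Simplified stand-in for the branch structure discussed in the issue
    # (not jaxlie's code). Both branches are always evaluated.
    use_taylor = norm < TAYLOR_EPS
    taylor = 2.0 / w - 2.0 / 3.0 * norm**2 / w**3  # inf/NaN when w == 0
    exact = 2.0 * jnp.arctan2(norm, w) / norm  # NaN when norm == 0
    return jnp.where(use_taylor, taylor, exact)


def atan_factor_safe(w, norm):
    use_taylor = norm < TAYLOR_EPS
    # "Double where": swap in 1.0 for the inputs of whichever branch is
    # discarded, so its backward pass never multiplies 0 by inf.
    safe_w = jnp.where(use_taylor, w, 1.0)
    safe_norm = jnp.where(use_taylor, 1.0, norm)
    taylor = 2.0 / safe_w - 2.0 / 3.0 * norm**2 / safe_w**3
    exact = 2.0 * jnp.arctan2(safe_norm, w) / safe_norm
    return jnp.where(use_taylor, taylor, exact)


# w == 0, norm == 1 corresponds to a rotation by pi (exact branch selected).
print(jax.grad(atan_factor_unsafe)(0.0, 1.0))  # nan
print(jax.grad(atan_factor_safe)(0.0, 1.0))  # -2.0
```

The same trick applies to any `jnp.where` whose discarded branch can hit a singularity: guard the inputs of that branch, not just the selected output.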