tensorflow / graphics

TensorFlow Graphics: Differentiable Graphics Layers for TensorFlow
Apache License 2.0

quaternion_from_rotation_matrix NaN or +Inf errors #462

Open areiner222 opened 4 years ago

areiner222 commented 4 years ago

I am trying to convert a 6D continuous representation of a rotation matrix into an axis-angle representation in a network. During backprop, I occasionally encounter NaNs / +Inf. I was able to determine that the source is quaternion_from_rotation_matrix by using a combination of tf.debugging.enable_check_numerics() and the tensorflow_graphics debugging mode.

Sometimes the tfg debugger catches an improperly normalized rotation matrix (via assert_rotation_matrix_normalized). Other times it does not, and quaternion_from_rotation_matrix fails in one of the safe_ops.safe_unsigned_div ops, or when it takes the square root of trace + 1 while the trace is below -1 by a very small delta.

In converting the continuous representation to a rotation matrix, I'm using an implementation very similar to this one.
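For reference, the 6D-to-matrix step is usually a Gram-Schmidt construction. The following is a minimal NumPy sketch of that construction, not necessarily the implementation referred to above. The name `rotation_from_6d`, the split of the six numbers into two 3-vectors, and the choice to place the resulting axes as matrix columns are all assumptions.

```python
import numpy as np

def rotation_from_6d(x6, eps=1e-8):
    """Hypothetical helper: map (..., 6) arrays to (..., 3, 3) rotation matrices.

    The first and last three numbers are treated as two 3-vectors a1, a2 and
    orthonormalized with Gram-Schmidt; the third axis is their cross product.
    The three axes become the columns of the matrix.
    """
    a1, a2 = x6[..., :3], x6[..., 3:]
    b1 = a1 / np.maximum(np.linalg.norm(a1, axis=-1, keepdims=True), eps)
    # Remove the b1 component from a2, then normalize the remainder.
    a2_perp = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1
    b2 = a2_perp / np.maximum(np.linalg.norm(a2_perp, axis=-1, keepdims=True), eps)
    b3 = np.cross(b1, b2)  # unit length and right-handed, so det(R) = +1
    return np.stack([b1, b2, b3], axis=-1)

rng = np.random.default_rng(0)
R = rotation_from_6d(rng.normal(size=(4, 6)))
print(np.abs(R @ np.swapaxes(R, -1, -2) - np.eye(3)).max())  # float64 round-off level
```

Even with this construction, float32 round-off leaves R only approximately orthonormal, so trace(R) can land a hair below -1 for rotations near 180 degrees.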

Do you have any idea whether this is a bug or whether I might be running into some kind of numerical instability? I am working on a repro, but I am having trouble freezing the model state at the moment the error occurs so that I can extract the bad rotation matrix.

cem-keskin commented 4 years ago

quaternion.from_rotation_matrix uses tf.where to pick a method to populate the quaternion, and tf.where (and tf.cond) executes all branches regardless of the condition. This means that sqrt(trace + 1) may fail because of numerical instability, and this is only reported in debug mode, where we check for NaN values. But tf.where will not pick the NaN value, since the condition required to pick that branch is trace > 0, which does not hold here.

What seems to be happening is that your rotation matrices are not properly normalized before you call this function. The function assumes that the matrices are normalized and only checks this in debug mode. So I'd suggest explicitly orthogonalizing your unnormalized rotation matrices before calling this function.
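One common way to do that orthogonalization is to project each matrix onto the nearest rotation with an SVD (the special orthogonal Procrustes solution). The following is a minimal NumPy sketch; `project_to_so3` is a hypothetical name, and this is not a tfg function.

```python
import numpy as np

def project_to_so3(m):
    """Hypothetical helper: the rotation closest (Frobenius norm) to each (..., 3, 3) matrix."""
    u, _, vt = np.linalg.svd(m)
    # U @ Vt is the nearest orthogonal matrix; if it is a reflection (det = -1),
    # flip the column of U for the smallest singular value so that det(R) = +1.
    d = np.ones(m.shape[:-1])
    d[..., -1] = np.sign(np.linalg.det(u @ vt))
    return (u * d[..., None, :]) @ vt

noisy = np.eye(3) + 1e-3 * np.arange(9.0).reshape(3, 3)  # slightly off a rotation
R = project_to_so3(noisy)
print(np.abs(R @ R.T - np.eye(3)).max(), np.linalg.det(R))
```

Projection keeps trace(R) between -1 and 3 up to round-off. It cannot help when the matrix genuinely is a 180-degree rotation, though: there trace(R) = -1, and sqrt(trace + 1) has an infinite derivative.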

prafael18 commented 3 years ago

@cem-keskin I think that tf.where is more susceptible to these issues than tf.cond. I just encountered a similar issue where NaNs were backpropagated when the trace of the rotation matrix was -1. I think this stems from the fact that tf.where backpropagates 0s to the branch that is not selected. Because a NaN (or Inf) is computed along the way in that branch, multiplying it by 0 gives NaN rather than 0, and the NaNs are backpropagated all the way to the inputs. I was able to fix the issue by replacing tf.where with tf.cond.
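To see why a zero gradient does not protect the unselected branch, here is a hand-written reverse pass in NumPy (not TensorFlow) for a toy y = where(c, sqrt(x), 2 * x). The helper `grad_where_sqrt` and the linear `2 * x` branch are hypothetical stand-ins. As in autodiff, both branches are differentiated: the where sends the upstream gradient to the selected branch and 0 to the other one, and each branch multiplies what it receives by its own local derivative. That local derivative is infinite at x = 0 and NaN for x < 0, and 0 * inf and 0 * nan are both NaN. The `safe=True` path sketches the widely used "double where" workaround, which feeds the sqrt a harmless input wherever its result is discarded.

```python
import numpy as np

def grad_where_sqrt(x, use_sqrt, safe=False):
    """Hand-written reverse pass of y = where(use_sqrt, sqrt(x), 2 * x): returns dy/dx."""
    x = np.asarray(x, dtype=float)
    x_in = np.where(use_sqrt, x, 1.0) if safe else x  # inner where of the "double where"
    with np.errstate(divide="ignore", invalid="ignore"):
        d_sqrt = 0.5 / np.sqrt(x_in)                  # inf at 0, nan below 0
        # Outer where: upstream gradient 1.0 goes to the selected branch, 0.0 to the other.
        g_sqrt = np.where(use_sqrt, 1.0, 0.0) * d_sqrt
    if safe:
        g_sqrt = np.where(use_sqrt, g_sqrt, 0.0)      # backward of the inner where
    g_lin = np.where(use_sqrt, 0.0, 1.0) * 2.0
    return g_sqrt + g_lin

x = np.array([4.0, 0.0, -1e-7])
print(grad_where_sqrt(x, x > 0))             # values: 0.25, nan, nan
print(grad_where_sqrt(x, x > 0, safe=True))  # values: 0.25, 2.0, 2.0
```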

I'm not sure what you meant when you said that both tf.where and tf.cond execute all branches regardless of condition. The documentation seems to indicate otherwise.

If x < y, the tf.add operation will be executed and tf.square operation will not be executed.

Here is a minimal reproducible example of the issue, where I'm using tensorflow==2.3.0 and tensorflow-graphics==2020.5.20:

import tensorflow as tf
import tensorflow_graphics.geometry.transformation as tfg

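# trace(so3) = 0.873557 - 0.901813 - 0.971744 = -1.0, so inside
# from_rotation_matrix, sqrt(trace + 1) is evaluated at 0 (up to round-off).
so3 = tf.Variable(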
    [[[0.873557, 0.42894128, -0.23001657], [0.42886758, -0.901813, -0.05297272],
      [-0.23015413, -0.05237198, -0.971744]]],
    dtype=tf.float32)
with tf.GradientTape() as tape:
  quat = tfg.quaternion.from_rotation_matrix(so3)
print(tape.gradient(quat, so3))

which outputs

tf.Tensor(
[[[       nan  0.2582983  0.2582983]
  [ 0.2582983        nan -0.2582983]
  [ 0.2582983  0.2582983        nan]]], shape=(1, 3, 3), dtype=float32)
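The NaNs above appear only on the diagonal. A short NumPy check, which is only an illustration, shows why. The branch labels follow the three conditions used by from_rotation_matrix (trace > 0, R[0,0] largest, R[1,1] > R[2,2]), and the sqrt argument of each branch is assumed to be the standard one for this trace-based method. Every branch's square root reads only diagonal entries, so a non-finite local derivative can only reach R[0,0], R[1,1] and R[2,2]. For this matrix, the discarded trace > 0 branch evaluates sqrt(trace + 1) at zero.

```python
import numpy as np

R = np.array([[0.873557, 0.42894128, -0.23001657],
              [0.42886758, -0.901813, -0.05297272],
              [-0.23015413, -0.05237198, -0.971744]])
r00, r11, r22 = np.diag(R)
# Argument of the square root in each of the four branches (diagonal entries only).
sqrt_args = {
    "trace > 0": 1.0 + r00 + r11 + r22,
    "R[0,0] largest": 1.0 + r00 - r11 - r22,
    "R[1,1] > R[2,2]": 1.0 - r00 + r11 - r22,
    "otherwise": 1.0 - r00 - r11 + r22,
}
for name, arg in sqrt_args.items():
    print(f"{name:>16}: {arg:+.6f}")
# The selected branch is "R[0,0] largest" (argument about 3.75), but the discarded
# "trace > 0" branch sits at sqrt(0), where the derivative is infinite.
```

In float32, the same sum can come out slightly negative, which would match the "trace below -1 by a very small delta" case from the original report.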

What seemed to fix the issue for me (at least for this example) was to switch the tf.where ops to tf.cond ops at the end of quaternion.from_rotation_matrix. Something like:

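# entries, trace, tr_positive, cond_1, cond_2 and cond_3 are the local tensors and
# branch functions inside quaternion.from_rotation_matrix. tf.cond needs a scalar
# predicate, so this version only handles a single (unbatched) matrix.
where_2 = tf.cond(entries[1][1] > entries[2][2], cond_2, cond_3)
where_1 = tf.cond((entries[0][0] > entries[1][1]) & (entries[0][0] > entries[2][2]), cond_1, lambda: where_2)
quat = tf.cond(trace > 0, tr_positive, lambda: where_1)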

With these changes, the example above outputs

tf.Tensor(
[[[ 0.11542333  0.2582983  0.2582983]
  [ 0.2582983  -0.11542333  -0.2582983]
  [ 0.2582983  0.2582983  -0.11542333]]], shape=(1, 3, 3), dtype=float32)

which are much more reasonable gradients.

cem-keskin commented 3 years ago

@prafael18, I was referring to the warning in the documentation: "Warning: Any Tensors or Operations created outside of true_fn and false_fn will be executed regardless of which branch is selected at runtime." But you are right that tf.cond should actually work in this case, because the problematic tensors are created inside true_fn and false_fn. Thanks for confirming it.

cem-keskin commented 3 years ago

Of course, tf.cond only works for a single rotation matrix and not for a batch, unfortunately, because its predicate must be a scalar. That's why we ended up using tf.where. Using tf.dynamic_partition and tf.dynamic_stitch is also an option.
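The partition/stitch idea can be sketched in NumPy as follows. The sketch computes a branch index per matrix, runs each branch formula only on its own rows, and writes the results back in place. The function names are hypothetical, and NumPy index arrays stand in for tf.dynamic_partition and tf.dynamic_stitch. The branch formulas are the standard trace-based ones, with quaternions in (x, y, z, w) order, which I am assuming matches the tfg convention.

```python
import numpy as np

def _quat_for_branch(m, b):
    """(x, y, z, w) quaternions for matrices m of shape (K, 3, 3) that all use branch b."""
    m00, m11, m22 = m[:, 0, 0], m[:, 1, 1], m[:, 2, 2]
    if b == 0:    # trace > 0, s = 4w
        s = 2.0 * np.sqrt(1.0 + m00 + m11 + m22)
        q = [(m[:, 2, 1] - m[:, 1, 2]) / s, (m[:, 0, 2] - m[:, 2, 0]) / s,
             (m[:, 1, 0] - m[:, 0, 1]) / s, 0.25 * s]
    elif b == 1:  # m00 is the largest diagonal entry, s = 4x
        s = 2.0 * np.sqrt(1.0 + m00 - m11 - m22)
        q = [0.25 * s, (m[:, 0, 1] + m[:, 1, 0]) / s,
             (m[:, 0, 2] + m[:, 2, 0]) / s, (m[:, 2, 1] - m[:, 1, 2]) / s]
    elif b == 2:  # m11 > m22, s = 4y
        s = 2.0 * np.sqrt(1.0 - m00 + m11 - m22)
        q = [(m[:, 0, 1] + m[:, 1, 0]) / s, 0.25 * s,
             (m[:, 1, 2] + m[:, 2, 1]) / s, (m[:, 0, 2] - m[:, 2, 0]) / s]
    else:         # otherwise, s = 4z
        s = 2.0 * np.sqrt(1.0 - m00 - m11 + m22)
        q = [(m[:, 0, 2] + m[:, 2, 0]) / s, (m[:, 1, 2] + m[:, 2, 1]) / s,
             0.25 * s, (m[:, 1, 0] - m[:, 0, 1]) / s]
    return np.stack(q, axis=-1)

def quaternion_from_rotation_matrix_partitioned(r):
    """Hypothetical sketch: (N, 3, 3) matrices -> (N, 4) quaternions in (x, y, z, w) order."""
    r = np.asarray(r, dtype=float)
    d0, d1, d2 = r[:, 0, 0], r[:, 1, 1], r[:, 2, 2]
    # Branch index per matrix, with the same priority order as the tf.where chain.
    branch = np.where(d0 + d1 + d2 > 0, 0,
                      np.where((d0 > d1) & (d0 > d2), 1, np.where(d1 > d2, 2, 3)))
    q = np.empty((r.shape[0], 4))
    for b in range(4):
        rows = np.flatnonzero(branch == b)          # "partition": matrices using branch b
        if rows.size:
            q[rows] = _quat_for_branch(r[rows], b)  # "stitch": scatter back in order
    return q

rz90 = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
print(quaternion_from_rotation_matrix_partitioned(np.stack([rz90, np.diag([1.0, -1.0, -1.0])])))
# rows ≈ (0, 0, 0.707, 0.707) and (1, 0, 0, 0)
```

Because no sqrt ever sees the arguments of an unused branch, there is nothing for a zero gradient to multiply into a NaN. In TensorFlow, the same routing should keep the discarded branches out of the backward pass as well.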

prafael18 commented 3 years ago

I'm not sure what the performance implications of my solution are, but I ended up reshaping the batch dimensions into a single one, applying tf.map_fn to each element in the batch, and restoring the original batch shape.
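The reshape/map/restore recipe can be sketched in NumPy, with a Python loop standing in for tf.map_fn (the helper name is hypothetical). Because fn sees one matrix at a time, it can choose a branch with a scalar test (tf.cond in TensorFlow, a plain if here) and never evaluate the others. The cost is that matrices are processed one by one rather than in a single vectorized op, which is usually slower.

```python
import numpy as np

def map_over_matrices(fn, x):
    """Hypothetical helper: apply fn to every (3, 3) matrix of an (..., 3, 3) array.

    Flattens all batch dimensions into one, applies fn one matrix at a time
    (the role tf.map_fn plays in TensorFlow), then restores the batch shape.
    fn must return arrays of a single fixed shape; the batch must be non-empty.
    """
    batch_shape = x.shape[:-2]
    flat = x.reshape((-1,) + x.shape[-2:])             # (N, 3, 3)
    out = np.stack([np.asarray(fn(m)) for m in flat])  # (N, *fn_output_shape)
    return out.reshape(batch_shape + out.shape[1:])

x = np.random.default_rng(0).normal(size=(2, 5, 3, 3))
print(map_over_matrices(np.trace, x).shape)  # (2, 5)
```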

SuwoongHeo commented 3 years ago

This issue still occurs. Following the suggestion here (https://github.com/tensorflow/tensorflow/issues/38349#issuecomment-836154973), I managed to solve it. Specifically, I put sq=tf.maximum(~~, eps_addition) in every place where sq is computed in from_rotation_matrix. However, I'm not sure whether this yields the correct gradient values.
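Whether a clamp is exact can be checked with a short argument. Assume the maximum is applied to the argument of each sqrt (the comment elides the clamped expression, so this is my assumption), and use the branch conditions of the tf.where chain. If trace > 0, the selected argument is 1 + trace > 1. Otherwise trace ≤ 0, and the selected branch is the one for a largest diagonal entry d_max. Its argument is 1 + 2·d_max − trace ≥ 1 − trace/3 ≥ 1, because d_max ≥ trace/3. So for any 3x3 input, in exact arithmetic, the selected branch's argument is at least 1. With eps below 1, the clamp therefore only changes branches that tf.where discards, and the selected value and its gradient are untouched. The NumPy sketch below checks this on random matrices; all names and eps = 1e-6 are hypothetical. A clamp applied only after the sqrt would not help in the same way, since max would still send 0 into the sqrt's infinite derivative.

```python
import numpy as np

def sqrt_args(r):
    """Arguments of the four branch square roots for (N, 3, 3) matrices."""
    d0, d1, d2 = r[:, 0, 0], r[:, 1, 1], r[:, 2, 2]
    return np.stack([1 + d0 + d1 + d2, 1 + d0 - d1 - d2,
                     1 - d0 + d1 - d2, 1 - d0 - d1 + d2], axis=-1)

def selected_branch(r):
    """Branch index per matrix, in the same priority order as the tf.where chain."""
    d0, d1, d2 = r[:, 0, 0], r[:, 1, 1], r[:, 2, 2]
    return np.where(d0 + d1 + d2 > 0, 0,
                    np.where((d0 > d1) & (d0 > d2), 1, np.where(d1 > d2, 2, 3)))

def clamped_sqrt(u, eps=1e-6):
    """sqrt(max(u, eps)) and its derivative w.r.t. u; both are finite for every u."""
    v = np.maximum(u, eps)
    # max passes gradient only where u > eps; elsewhere 0 multiplies the finite
    # 0.5 / sqrt(v), so the product is a clean 0 instead of nan.
    return np.sqrt(v), np.where(u > eps, 0.5 / np.sqrt(v), 0.0)

rng = np.random.default_rng(0)
r = rng.normal(size=(10000, 3, 3))  # arbitrary matrices, not even rotations
args = sqrt_args(r)
chosen = np.take_along_axis(args, selected_branch(r)[:, None], axis=-1)[:, 0]
value, deriv = clamped_sqrt(args)
print(chosen.min(), np.isfinite(deriv).all())
```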