junpenglao opened 4 years ago
The bug seems to have been introduced in https://github.com/tensorflow/tensorflow/commit/c2e594440e1d9839546b93a93d8646b06891d7de; it was discovered in https://github.com/pymc-devs/pymc4/issues/317#issuecomment-683416761.
@davmre I am assigning this to you since you have a bit more context about the change.
Thanks for reporting; this is a pretty weird issue. After poking for a bit, I think the root of the problem is that this snippet computing the gradient of `tf.gather`:
```python
import tensorflow as tf

@tf.function(autograph=False)
def gather_grad(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    v = tf.gather(x, 0)
  g = tape.gradient(v, x)
  return g

gather_grad(x=tf.convert_to_tensor([1.]))
```
returns a `<google3.third_party.tensorflow.python.framework.indexed_slices.IndexedSlices at 0x7fb5f27bc4a8>` instance instead of a simple `Tensor`. The `IndexedSlices` instance is convertible to a `Tensor`, but its underlying representation uses two `Tensor`s (one for the gathered values, the other for the indices), and that screws up the `HamiltonianMonteCarlo` `while_loop`, which expects to see the same `Tensor` structure it was initialized with.
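To make that concrete, here is a small check (my own sketch, not from the thread) of the two-`Tensor` representation and of densifying it with `tf.convert_to_tensor`:

```python
import tensorflow as tf

x = tf.convert_to_tensor([1., 2.])
with tf.GradientTape() as tape:
  tape.watch(x)
  v = tf.gather(x, 0)
g = tape.gradient(v, x)  # an IndexedSlices, not a plain Tensor

# The two underlying Tensors described above:
print(g.values)   # tf.Tensor([1.], shape=(1,), dtype=float32)
print(g.indices)  # tf.Tensor([0], shape=(1,), dtype=int32)

# ...and it converts to a single dense Tensor:
print(tf.convert_to_tensor(g))  # tf.Tensor([1. 0.], shape=(2,), dtype=float32)
```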
The contribution of tensorflow/tensorflow@c2e5944 is somewhat tangential: it calls `tf.gather(x, 0)` for unit-batch `Tensor`s directly, where previously the autovectorization machinery would see `tf.gather(x, i)` (where `i` is an abstract batch index) and do something more complicated that I think might end up eliding the gather altogether. The change is fine IMHO, but it seems to have triggered this complicated interaction.
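For a sense of why the code path matters, a contrast I put together (assuming TF 2.x eager semantics; it does not reproduce the pre-commit vectorized path): the gather path yields an `IndexedSlices` gradient, while an ordinary slice yields a dense `Tensor`.

```python
import tensorflow as tf

x = tf.convert_to_tensor([1., 2.])
with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  via_gather = tf.gather(x, 0)  # the path taken directly post-commit
  via_slice = x[0]              # an ordinary strided slice, for contrast

print(type(tape.gradient(via_gather, x)).__name__)  # IndexedSlices
print(type(tape.gradient(via_slice, x)).__name__)   # EagerTensor (dense)
```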
I think we'll need to consult the TF Core team on the most natural fix: it might make sense to change the gradient definition for `tf.gather`, or for `while_loop` to try to convert any `CompositeTensor`s in its loop state to `Tensor`s before giving up. I'll file a couple of bugs.
Actually, it might make more sense to just work around this at the TFP level by calling `convert_to_tensor` on all gradients inside the MCMC loop. I'll follow up tomorrow.
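A sketch of what such a workaround could look like (`densify_gradients` is a hypothetical helper, not the actual TFP patch):

```python
import tensorflow as tf

def densify_gradients(grads):
  # Hypothetical helper: convert any IndexedSlices gradients back to
  # dense Tensors so the loop state keeps the Tensor structure the
  # while_loop was initialized with.
  return tf.nest.map_structure(
      lambda g: g if g is None else tf.convert_to_tensor(g), grads)
```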
Minimal reproducible example:
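(The original snippet and its error output were not captured here; the following is my reconstruction of a repro along the lines described in this thread, with the target log-prob and all values as assumptions.)

```python
import tensorflow as tf
import tensorflow_probability as tfp

def target_log_prob_fn(x):
  # tf.gather in the target makes its gradient an IndexedSlices.
  return -tf.reduce_sum(tf.square(tf.gather(x, [0])))

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn,
    step_size=0.1,
    num_leapfrog_steps=3)

samples = tfp.mcmc.sample_chain(
    num_results=10,
    current_state=tf.constant([1., 2.]),
    kernel=kernel,
    trace_fn=None)
```

On an affected TF build, this should fail inside the sampler's `while_loop` with a structure mismatch between the initial `Tensor` state and the `IndexedSlices` gradient, per the analysis above.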