tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Recent change in tf.vectorized_map breaks MCMC when batch_shape = 1 #1071

Open junpenglao opened 4 years ago

junpenglao commented 4 years ago

Minimal reproducible example:

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.Normal(0., 1.)

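def vectorized_logpfn(*state):
    # Evaluate dist.log_prob separately for each entry along the leading (chain) axis.
    return tf.vectorized_map(lambda mini_state: dist.log_prob(*mini_state), state)

init = dist.sample(1)  # shape [1]: a single chain, i.e. the batch_shape = 1 case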

@tf.function
def run_fn(init, burn_in):
    return tfp.mcmc.sample_chain(
        10, init, 
        num_burnin_steps=burn_in,
        kernel=tfp.mcmc.HamiltonianMonteCarlo(
            vectorized_logpfn, .1, num_leapfrog_steps=5))

run_fn(init, 10)

raises:

ValueError                                Traceback (most recent call last)
<ipython-input-17-28b7450607cd> in <module>
      7             vectorized_logpfn, .1, num_leapfrog_steps=5))
      8 
----> 9 run_fn(init, 10)

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    784     tracing_count = self._get_tracing_count()
    785     with trace.Trace(self._name) as tm:
--> 786       result = self._call(*args, **kwds)
    787       compiler = "xla" if self._experimental_compile else "nonXla"
    788       new_tracing_count = self._get_tracing_count()

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    827       # This is the first call of __call__, so we have to initialize.
    828       initializers = []
--> 829       self._initialize(args, kwds, add_initializers_to=initializers)
    830     finally:
    831       # At this point we know that the initialization is complete (or less

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    715     self._concrete_stateful_fn = (
    716         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 717             *args, **kwds))
    718 
    719     def invalid_creator_scope(*unused_args, **unused_kwds):

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2953       args, kwargs = None, None
   2954     with self._lock:
-> 2955       graph_function, _ = self._maybe_define_function(args, kwargs)
   2956     return graph_function
   2957 

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3353 
   3354           self._function_cache.missed.add(call_context_key)
-> 3355           graph_function = self._create_graph_function(args, kwargs)
   3356           self._function_cache.primary[cache_key] = graph_function
   3357 

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3198             arg_names=arg_names,
   3199             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3200             capture_by_value=self._capture_by_value),
   3201         self._function_attributes,
   3202         function_spec=self.function_spec,

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    985         _, original_func = tf_decorator.unwrap(python_func)
    986 
--> 987       func_outputs = python_func(*func_args, **func_kwargs)
    988 
    989       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    623             xla_context.Exit()
    624         else:
--> 625           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    626         return out
    627 

~/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    972           except Exception as e:  # pylint:disable=broad-except
    973             if hasattr(e, "ag_error_metadata"):
--> 974               raise e.ag_error_metadata.to_exception(e)
    975             else:
    976               raise

ValueError: in user code:

    <ipython-input-17-28b7450607cd>:4 run_fn  *
        10, init,
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/sample.py:374 sample_chain  **
        parallel_iterations=parallel_iterations)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/internal/util.py:464 trace_scan
        parallel_iterations=parallel_iterations)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:574 new_func
        return func(*args, **kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2499 while_loop_v2
        return_same_structure=True)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2696 while_loop
        back_prop=back_prop)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:196 while_loop
        add_control_dependencies=add_control_dependencies)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:987 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:174 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/internal/util.py:450 _body
        state = loop_fn(state, elem)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/sample.py:358 _trace_scan_fn
        parallel_iterations=parallel_iterations)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/internal/util.py:353 smart_for_loop
        parallel_iterations=parallel_iterations
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:574 new_func
        return func(*args, **kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2499 while_loop_v2
        return_same_structure=True)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2696 while_loop
        back_prop=back_prop)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:196 while_loop
        add_control_dependencies=add_control_dependencies)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:987 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:174 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/internal/util.py:351 <lambda>
        body=lambda i, *args: [i + 1] + list(body_fn(*args)),
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/sample.py:351 _seeded_one_step
        kernel.one_step(*state_and_results, **one_step_kwargs))
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/hmc.py:574 one_step
        current_state, previous_kernel_results, seed=seed)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/metropolis_hastings.py:218 one_step
        **inner_kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/hmc.py:777 one_step
        current_target_log_prob_grad_parts)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py:291 __call__
        target_grad_parts,
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:574 new_func
        return func(*args, **kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2499 while_loop_v2
        return_same_structure=True)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py:2696 while_loop
        back_prop=back_prop)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:196 while_loop
        add_control_dependencies=add_control_dependencies)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:987 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/while_v2.py:180 wrapped_body
        expand_composites=True)
    /home/junpenglao/miniconda3/lib/python3.7/site-packages/tensorflow/python/util/nest.py:411 assert_same_structure
        % (str(e), str1, str2))

    ValueError: The two structures don't have the same nested structure.

    First structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/add:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/add_1:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/add:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/pfor/Tile:0' shape=(1,) dtype=float32>, [<tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x7f37d3f45e90>]]

    Second structure: type=list str=[<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/iter:0' shape=() dtype=int32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/add:0' shape=(1,) dtype=float32>], [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/Placeholder_2:0' shape=(1,) dtype=float32>], <tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/Placeholder_4:0' shape=(1,) dtype=float32>, [<tf.Tensor 'mcmc_sample_chain/trace_scan/while/smart_for_loop/while/Placeholder_5:0' shape=(1,) dtype=float32>]]

    More specifically: Substructure "type=IndexedSlices str=IndexedSlices(indices=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape_1:0", shape=(1,), dtype=int32), values=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Reshape:0", shape=(1,), dtype=float32), dense_shape=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/gradients/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/value_and_gradient/loop_body/GatherV2_grad/Cast:0", shape=(1,), dtype=int32))" is a sequence, while substructure "type=Tensor str=Tensor("mcmc_sample_chain/trace_scan/while/smart_for_loop/while/Placeholder_5:0", shape=(1,), dtype=float32)" is not
    Entire first structure:
    [., [.], [.], ., [.]]
    Entire second structure:
    [., [.], [.], ., [.]]
junpenglao commented 4 years ago

The bug seems to have been introduced in https://github.com/tensorflow/tensorflow/commit/c2e594440e1d9839546b93a93d8646b06891d7de# and was discovered in https://github.com/pymc-devs/pymc4/issues/317#issuecomment-683416761.

@davmre I am assigning this to you since you have a bit more context about the change.

davmre commented 4 years ago

Thanks for reporting; this is a pretty weird issue. After poking around for a bit, I think the root of the problem is that this snippet, which computes the gradient of tf.gather:

import tensorflow as tf

@tf.function(autograph=False)
def gather_grad(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    v = tf.gather(x, 0)
  g = tape.gradient(v, x)
  return g
gather_grad(x=tf.convert_to_tensor([1.]))

returns a <google3.third_party.tensorflow.python.framework.indexed_slices.IndexedSlices at 0x7fb5f27bc4a8> instance instead of a plain Tensor. The IndexedSlices instance is convertible to a Tensor, but its underlying representation uses several Tensors (the gradient values for the gathered slices, their indices, and the dense shape, as the traceback above shows), and that screws up the HamiltonianMonteCarlo while_loop, which expects to see the same Tensor structure it was initialized with.
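To make the IndexedSlices point concrete, here is a minimal eager-mode sketch (assuming TF 2.x; the variable names are illustrative only). It inspects the three component Tensors of the gather gradient and shows that tf.convert_to_tensor turns it back into an ordinary dense Tensor:

```python
import tensorflow as tf

x = tf.constant([1.])

with tf.GradientTape() as tape:
    tape.watch(x)
    v = tf.gather(x, 0)

# The registered gradient of tf.gather w.r.t. its params is sparse.
g = tape.gradient(v, x)
print(type(g).__name__)
print(g.values)       # gradient values for the gathered rows
print(g.indices)      # which rows of x those values belong to
print(g.dense_shape)  # shape of the equivalent dense gradient

# convert_to_tensor scatters the slices into an ordinary dense Tensor.
dense = tf.convert_to_tensor(g)
print(dense)
```

The dense result has the same shape as x, which is the structure a while_loop initialized with a plain Tensor would expect.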

The role of tensorflow/tensorflow@c2e5944 is somewhat tangential: for unit-batch Tensors it calls tf.gather(x, 0) directly, whereas previously the autovectorization machinery would see tf.gather(x, i) (where i is an abstract batch index) and do something more complicated that I think might end up eliding the gather altogether. The change is fine IMHO, but it seems to have triggered this complicated interaction.
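A small sketch of the unit-batch case, assuming TF 2.x and using a hand-written standard-normal log-density in place of the TFP distribution from the report. Whether the raw gradient comes back as a Tensor or an IndexedSlices depends on how the installed TF version's vectorized_map handles a size-1 batch, so the sketch does not assert the type; it only shows that densifying yields the expected gradient -x either way:

```python
import tensorflow as tf

def std_normal_log_prob(s):
    # Unnormalized standard-normal log-density; its gradient is -s.
    return -0.5 * tf.square(s)

@tf.function(autograph=False)
def unit_batch_grad(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        # x has a leading batch dimension of size 1, as in the report above.
        lp = tf.vectorized_map(std_normal_log_prob, x)
    g = tape.gradient(lp, x)
    # g may be a Tensor or an IndexedSlices depending on the TF version;
    # convert_to_tensor gives a dense Tensor in both cases.
    return tf.convert_to_tensor(g)

dense_grad = unit_batch_grad(tf.constant([0.5]))
print(dense_grad)
```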

I think we'll need to consult the TF Core team on the most natural fix: it might make sense to change the gradient definition for tf.gather, or to have while_loop try to convert any CompositeTensors in its loop state to Tensors before giving up. I'll file a couple of bugs.
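The second option (converting CompositeTensors in the loop state) can be approximated in user code with tf.nest. This is only a sketch assuming TF 2.x, not what TF actually does inside while_loop; densify_indexed_slices is a hypothetical helper name, not an existing TF API. The loop state mimics the [int, [Tensor], [Tensor], Tensor, [grad]] shape seen in the traceback, simplified:

```python
import tensorflow as tf

def densify_indexed_slices(structure):
    """Hypothetical helper: replace every tf.IndexedSlices leaf with a dense Tensor."""
    # Without expand_composites, tf.nest treats IndexedSlices as a single leaf.
    return tf.nest.map_structure(
        lambda t: tf.convert_to_tensor(t) if isinstance(t, tf.IndexedSlices) else t,
        structure)

x = tf.constant([1.])
with tf.GradientTape() as tape:
    tape.watch(x)
    v = tf.gather(x, 0)

# A loop-state-like nested list whose last entry holds the sparse gather gradient.
state = [tf.constant(0), [x], v, [tape.gradient(v, x)]]
dense_state = densify_indexed_slices(state)
```

After the call, every leaf of dense_state is a plain Tensor, so its structure matches a state that was initialized with Tensors only.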

davmre commented 4 years ago

Actually, it might make more sense to just work around this at the TFP level by calling convert_to_tensor on all gradients inside the MCMC loop. I'll follow up tomorrow.
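A minimal sketch of that TFP-level workaround, using a toy leapfrog loop rather than TFP's actual integrator (assuming TF 2.x). target_log_prob, value_and_dense_grad and toy_leapfrog are hypothetical names for illustration, not TFP internals. The target routes through tf.gather so its raw gradient is an IndexedSlices; densifying it before it enters the loop state keeps the while_loop structure all-Tensor. Without the convert_to_tensor call, TF versions of that era would be expected to raise the same structure-mismatch error as in the traceback above:

```python
import tensorflow as tf

def target_log_prob(x):
    # Toy standard-normal target; tf.gather makes the raw gradient an IndexedSlices.
    return -0.5 * tf.square(tf.gather(x, 0))

def value_and_dense_grad(x):
    with tf.GradientTape() as tape:
        tape.watch(x)
        lp = target_log_prob(x)
    grad = tape.gradient(lp, x)
    # The workaround: densify the gradient before it becomes loop state.
    return lp, tf.convert_to_tensor(grad)

@tf.function(autograph=False)
def toy_leapfrog(x0, step_size=0.1, num_steps=5):
    momentum0 = tf.zeros_like(x0)
    _, grad0 = value_and_dense_grad(x0)

    def body(i, x, m, grad):
        # Standard leapfrog: half momentum step, full position step, half momentum step.
        m = m + 0.5 * step_size * grad
        x = x + step_size * m
        _, grad = value_and_dense_grad(x)
        m = m + 0.5 * step_size * grad
        return i + 1, x, m, grad

    # The loop state [i, x, m, grad] keeps the same all-Tensor structure every iteration.
    return tf.while_loop(lambda i, *_: i < num_steps, body,
                         (tf.constant(0), x0, momentum0, grad0))

_, x_final, m_final, g_final = toy_leapfrog(tf.constant([1.]))
print(x_final, g_final)
```

Because every gradient is densified at the single point where it is computed, the rest of the loop never has to know about IndexedSlices.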