tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Incompatible shapes in `make_momentum_distribution` #1640

Closed chrism0dwk closed 1 year ago

chrism0dwk commented 1 year ago

Hi Guys,

I could use a little help getting a DiagonalMassMatrixAdaptation / DualAveragingStepSizeAdaptation / PreconditionedHamiltonianMonteCarlo onion going. I begin with a StructTuple of state parts, which came out of a JointDistributionCoroutine model, and encounter an error in constructing the initial momentum distribution.

For example:

import tensorflow as tf
from tensorflow_probability.python.internal import structural_tuple as st

foo = st.structtuple(["bar", "baz"])(tf.constant(1.0), tf.constant([1.0, 2.0, 3.0]))

Then to construct the momentum distribution for the PreconditionedHamiltonianMonteCarlo I do

import tensorflow_probability.python.experimental.mcmc.preconditioning_utils as phmc_utils
momentum = phmc_utils.make_momentum_distribution(
    foo,
    batch_shape=(0,),
)

allowing the running_variance_parts to be set by the function call. However, I get an error:

InvalidArgumentError: {{function_node __wrapped__BroadcastTo_device_/job:localhost/replica:0/task:0/device:CPU:0}} Incompatible shapes: [3] vs. [0] [Op:BroadcastTo]

Have I stepped outside the design envelope, I wonder, or is there a bug?

Cheers,

Chris

ColCarroll commented 1 year ago

Hi Chris -- The preconditioning_utils are not the most ergonomic, nor the most public, but this code almost works.

If you use batch_shape=tf.TensorShape([]), it "works on my colab". The batch_shape here can be set to the tf.shape of the log_prob of your model at the initial point(s).
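To make the batch_shape advice concrete: TFP state parts are conventionally laid out as batch_shape + event_shape, so a single chain with a scalar log_prob has an empty batch shape. The following is a minimal NumPy sketch of that convention only. It is not the library's implementation and does not reproduce the BroadcastTo error; `split_shape` is a hypothetical helper introduced for illustration.

```python
import numpy as np


def split_shape(state_part, batch_shape):
    """Hypothetical helper: split a state part's shape into (batch, event).

    Assumes the convention state_part.shape == batch_shape + event_shape.
    """
    batch_shape = tuple(batch_shape)
    shape = np.shape(state_part)
    n = len(batch_shape)
    if shape[:n] != batch_shape:
        raise ValueError(
            f"state part shape {shape} does not start with "
            f"batch shape {batch_shape}"
        )
    return batch_shape, shape[n:]


# Same values as `foo` in the report: a scalar and a length-3 vector.
state_parts = [np.float32(1.0), np.array([1.0, 2.0, 3.0], dtype=np.float32)]

# Single chain: log_prob is a scalar, so the batch shape is empty.
event_shapes = [split_shape(p, batch_shape=())[1] for p in state_parts]
print(event_shapes)  # [(), (3,)]

# batch_shape=(0,) claims a leading axis of length 0, which neither part has.
try:
    split_shape(state_parts[1], batch_shape=(0,))
    mismatch = False
except ValueError:
    mismatch = True
```

With the empty batch shape, the event shapes match the shapes of the two momentum samples: `()` and `(3,)`.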

Note that in general the MCMC stack expects lists of state parts, so momentum will not be a structure-aware object:

>>> momentum.sample()
[<tf.Tensor: shape=(), dtype=float32, numpy=-2.8567312>,
 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 0.55520535, -1.6700784 ,  1.2413677 ], dtype=float32)>]
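Since the sampler works on a plain list, the field names have to be dropped on the way in and restored on the way out. Below is a rough pure-Python sketch of that round trip. It uses `collections.namedtuple` as a stand-in for the StructTuple (an assumption made so the example runs without TensorFlow), and the "sampled" values are placeholders rather than real output. With real tensors, `tf.nest.flatten` and `tf.nest.pack_sequence_as` play the same role.

```python
from collections import namedtuple

# Stand-in for the StructTuple produced by the model (assumption for
# illustration; the real object comes from JointDistributionCoroutine).
State = namedtuple("State", ["bar", "baz"])
foo = State(bar=1.0, baz=[1.0, 2.0, 3.0])

# Flatten to a plain list of state parts, the form the MCMC stack uses.
state_parts = list(foo)

# Placeholder for whatever the sampler returns: a list with one entry
# per state part, in the same order. These are not real samples.
sampled_parts = [0.0, [0.0, 0.0, 0.0]]

# Restore the field names so downstream code can use .bar and .baz again.
restored = type(foo)(*sampled_parts)
print(restored._fields)  # ('bar', 'baz')
```

The order of the list matches the field order of the structure, which is what makes the repacking step safe.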

chrism0dwk commented 1 year ago

That's great, thanks again, and sorry for the disgraceful typo above (now edited!). A couple of times today TF 2.10.0 has explicitly demanded a TensorShape where previously a regular Tensor (or even a Python list) would do, so maybe the semantics are tightening up.

Chris