tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

log_prob issue when concatenating distributions across batch_shape #1432

Open lucasmiranda42 opened 3 years ago

lucasmiranda42 commented 3 years ago

Dear all,

I was wondering if there is any way of concatenating already-built distributions along a given batch_shape axis. For example:

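import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions

# Normal scales must be positive, so use the absolute value of the draws.
A = tfd.Normal(np.random.normal(size=[10, 2, 4]), np.abs(np.random.normal(size=[10, 2, 4])))
B = tfd.Normal(np.random.normal(size=[10, 4, 4]), np.abs(np.random.normal(size=[10, 4, 4])))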

# print(A); print(B)
tfp.distributions.Normal("Normal", batch_shape=[10, 2, 4], event_shape=[], dtype=float64)
tfp.distributions.Normal("Normal", batch_shape=[10, 4, 4], event_shape=[], dtype=float64)

# Concatenate A and B and get ->
<tfp.distributions.Normal 'Normal' batch_shape=[10, 6, 4] event_shape=[] dtype=float64>

I tried using tfd.Blockwise, but the 'concatenated' axis ends up being part of the event_shape. On a similar note, is there any way of 'converting' an event_shape dimension to a batch_shape one? (something like an inverse of tfd.Independent).

l = [A[:, i] for i in range(A.batch_shape[1])] + [B[:, i] for i in range(B.batch_shape[1])]    
tfd.Blockwise(l)

<tfp.distributions.Blockwise 'Blockwise' batch_shape=[10, 4] event_shape=[6] dtype=float64>

Thank you very much! Lucas

brianwa84 commented 3 years ago

We have this, but it is apparently not in the public interface: https://cs.opensource.google/tensorflow/probability/+/main:tensorflow_probability/python/distributions/batch_concat.py It would probably take only a small PR to add it to distributions/__init__.py.


lucasmiranda42 commented 3 years ago

Oh, that's wonderful. Thank you! Shall I proceed with the PR?

lucasmiranda42 commented 3 years ago

One more (related) question. I tried tfd.BatchConcat and it seems to work, but I get an error when attempting to compute log_prob. What am I doing wrong? Thanks!

First I instantiate a list of two distributions (I append to a pre-existing list because, in the original code, the second element depends on a condition):

# gm_logits, bm_logits and the trainable_* parameters are defined elsewhere in the original model.
emissions_per_feature = [
    tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=gm_logits),
        components_distribution=tfd.Normal(
            loc=trainable_means, scale=trainable_scales, allow_nan_stats=True
        ),
    )
]

emissions_per_feature.append(
    tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=bm_logits),
        components_distribution=tfd.Beta(
            concentration1=trainable_conc_1s,
            concentration0=trainable_conc_0s,
            force_probs_to_zero_outside_support=True,
        ),
    )
)

=> [<tfp.distributions.MixtureSameFamily 'MixtureSameFamily' batch_shape=[12, 4] event_shape=[] dtype=float32>, 
       <tfp.distributions.MixtureSameFamily 'MixtureSameFamily' batch_shape=[12, 2] event_shape=[] dtype=float32>]

Then I apply BatchConcat and Blockwise:

emission_distribution = tfd.Blockwise(
    tfd.BatchConcat(emissions_per_feature, axis=1),
    name="emission_distribution",
)

=> tfp.distributions.Blockwise("emission_distribution", batch_shape=[12, 6], event_shape=[1], dtype=float32)

This looks exactly like what I had in mind, but when I attempt to compute a log_prob (tf_model below is this Blockwise distribution), I get the following error:

tf_model.log_prob(tf.random.uniform(shape=[100, 12, 6, 1]))

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-331-cc1e20537792> in <module>
      1 print(tf_model)
----> 2 tf_model.log_prob(tf.random.uniform(shape=[100, 12, 6, 1]))

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py in log_prob(self, value, name, **kwargs)
   1294         values of type `self.dtype`.
   1295     """
-> 1296     return self._call_log_prob(value, name, **kwargs)
   1297 
   1298   def _call_prob(self, value, name, **kwargs):

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py in _call_log_prob(self, value, name, **kwargs)
   1276     with self._name_and_control_scope(name, value, kwargs):
   1277       if hasattr(self, '_log_prob'):
-> 1278         return self._log_prob(value, **kwargs)
   1279       if hasattr(self, '_prob'):
   1280         return tf.math.log(self._prob(value, **kwargs))

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow_probability/python/distributions/blockwise.py in _log_prob(self, x)
    335 
    336   def _log_prob(self, x):
--> 337     return self._distribution.log_prob(self._split_and_reshape_event(x))
    338 
    339   def _entropy(self):

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py in log_prob(self, value, name, **kwargs)
   1294         values of type `self.dtype`.
   1295     """
-> 1296     return self._call_log_prob(value, name, **kwargs)
   1297 
   1298   def _call_prob(self, value, name, **kwargs):

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow_probability/python/distributions/distribution.py in _call_log_prob(self, value, name, **kwargs)
   1276     with self._name_and_control_scope(name, value, kwargs):
   1277       if hasattr(self, '_log_prob'):
-> 1278         return self._log_prob(value, **kwargs)
   1279       if hasattr(self, '_prob'):
   1280         return tf.math.log(self._prob(value, **kwargs))

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_concat.py in _log_prob(self, x, **kwargs)
    440 
    441   def _log_prob(self, x, **kwargs):
--> 442     return self._call_split_concat('log_prob', x, **kwargs)
    443 
    444   def _prob(self, x, **kwargs):

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_concat.py in _call_split_concat(self, fn, x, **kwargs)
    425 
    426   def _call_split_concat(self, fn, x, **kwargs):
--> 427     sample_shape_size, split_x = self._split_sample(x)
    428     result = [
    429         getattr(d, fn)(i, **kwargs)

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow_probability/python/distributions/batch_concat.py in _split_sample(self, x)
    233     original_shapes = ps.stack(all_batch_shapes, axis=0)
    234     all_compose_shapes = ps.gather(original_shapes, self._axis, axis=1)
--> 235     x_split = tf.split(x, all_compose_shapes, axis=sample_shape_size+self._axis)
    236     return sample_shape_size, x_split
    237 

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    204     """Call target, and fall back on dispatchers if there is a TypeError."""
    205     try:
--> 206       return target(*args, **kwargs)
    207     except (TypeError, ValueError):
    208       # Note: convert_to_eager_tensor currently raises a ValueError, not a

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py in split(value, num_or_size_splits, axis, num, name)
   2143       raise ValueError("Cannot infer num from shape %s" % num_or_size_splits)
   2144 
-> 2145   return gen_array_ops.split_v(
   2146       value=value, size_splits=size_splits, axis=axis, num_split=num, name=name)
   2147 

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow/python/ops/gen_array_ops.py in split_v(value, size_splits, axis, num_split, name)
  10097       pass
  10098     try:
> 10099       return split_v_eager_fallback(
  10100           value, size_splits, axis, num_split=num_split, name=name, ctx=_ctx)
  10101     except _core._SymbolicException:

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow/python/ops/gen_array_ops.py in split_v_eager_fallback(value, size_splits, axis, num_split, name, ctx)
  10125   _inputs_flat = [value, size_splits, axis]
  10126   _attrs = ("num_split", num_split, "T", _attr_T, "Tlen", _attr_Tlen)
> 10127   _result = _execute.execute(b"SplitV", num_split, inputs=_inputs_flat,
  10128                              attrs=_attrs, ctx=ctx, name=name)
  10129   if _execute.must_record_gradient():

~/opt/anaconda3/envs/ICU_StateSegmentation/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57   try:
     58     ctx.ensure_initialized()
---> 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:

InvalidArgumentError: If there is only one output, it must have the same size as the input. Input size: 12 output size: 6 [Op:SplitV]
brianwa84 commented 3 years ago

Do you want [6]-shaped events? If so, you might try Blockwise([Independent(mixture1, 1), Independent(mixture2, 1)])

Feel free to send a PR, and open an issue about the specific issue you identified here (or re-title this issue).

lucasmiranda42 commented 3 years ago

Thanks! I renamed the issue :)

No, I need [12, 6]-shaped batches and [1]-shaped events from combining these two:

[<tfp.distributions.MixtureSameFamily 'MixtureSameFamily' batch_shape=[12, 4] event_shape=[] dtype=float32>, 
 <tfp.distributions.MixtureSameFamily 'MixtureSameFamily' batch_shape=[12, 2] event_shape=[] dtype=float32>]