Can't slice along batch for MixtureSameFamily produced by Keras layer #309

Status: Open. henrypinkard opened this issue 1 year ago

henrypinkard commented 1 year ago

I'm trying to make some dense layers that output the parameters of a Gaussian mixture (a mixture density network). I want to run a batch of data through the network (for speed), get out a batch of distributions, and then slice to work with only some elements of the batch at a time. If I were doing this with just tfp, I would call:

import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions

num_mixture_components = 12
batch_size = 4

# Mixture weights, normalized so each row sums to 1
probs = np.random.rand(batch_size, num_mixture_components)
probs /= probs.sum(axis=-1, keepdims=True)
loc = np.random.rand(batch_size, num_mixture_components)
scale = np.random.rand(batch_size, num_mixture_components)

gm = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(probs=probs),
      components_distribution=tfd.Normal(loc=loc, scale=scale))

print(gm)
# slice along batch
print(gm[1])

This works as expected, printing:

tfp.distributions.MixtureSameFamily("MixtureSameFamily", batch_shape=[4], event_shape=[], dtype=float64)
tfp.distributions.MixtureSameFamily("MixtureSameFamily", batch_shape=[], event_shape=[], dtype=float64)

However, when I try the same thing with a mixture density network in Keras, I get an error:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.keras.layers as tfkl

num_mixture_components = 12

l = tfkl.Input(shape=(100,))

# Make a fully connected network that outputs the parameters of a Gaussian mixture
mu = tfkl.Dense(units=num_mixture_components, activation=None)(l)
sigma = tfkl.Dense(units=num_mixture_components, activation='softplus')(l)
alpha = tfkl.Dense(units=num_mixture_components, activation='softmax')(l)
stacked = tfkl.Concatenate()([mu, sigma, alpha])

mixture = tfp.layers.MixtureNormal(num_mixture_components,
                                   event_shape=[], name="test")(stacked)
model = tf.keras.Model(inputs=l, outputs=mixture)

out = model(np.random.rand(4, 100))

gm = out.tensor_distribution

print(gm)
# slice along batch
print(gm[1])

Running this gives me a cryptic error:

tfp.distributions._MixtureSameFamily("MixtureSameFamily", batch_shape=[4], event_shape=[], dtype=float32)

TypeError                                 Traceback (most recent call last)
Cell In [72], line 27
     25 print(gm)
     26 # slice along batch
---> 27 print(gm[1])

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/distributions/, in Distribution.__getitem__(self, slices)
    825 def __getitem__(self, slices):
    826   """Slices the batch axes of this distribution, returning a new instance.
    828   ```python
    850     dist: A new `tfd.Distribution` instance with sliced parameters.
    851   """
--> 852   return slicing.batch_slice(self, {}, slices)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/, in batch_slice(batch_object, params_overrides, slices, bijector_x_event_ndims)
    217 slice_overrides_seq = slice_overrides_seq + [(slices, params_overrides)]
    218 # Re-doing the full sequence of slice+copy override work here enables
    219 # gradients all the way back to the original batch_objectribution's arguments.
--> 220 batch_object = _apply_slice_sequence(
    221     orig_batch_object,
    222     slice_overrides_seq,
    223     bijector_x_event_ndims=bijector_x_event_ndims)
    224 setattr(batch_object,
    225         PROVENANCE_ATTR,
    226         batch_object._no_dependency((orig_batch_object, slice_overrides_seq)))  # pylint: disable=protected-access
    227 return batch_object

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/, in _apply_slice_sequence(batch_object, slice_overrides_seq, bijector_x_event_ndims)
    177 """Applies a sequence of slice or copy-with-overrides operations to `batch_object`."""
    178 for slices, overrides in slice_overrides_seq:
--> 179   batch_object = _apply_single_step(
    180       batch_object,
    181       slices,
    182       overrides,
    183       bijector_x_event_ndims=bijector_x_event_ndims)
    184 return batch_object

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/, in _apply_single_step(batch_object, slices, params_overrides, bijector_x_event_ndims)
    166   override_dict = {}
    167 else:
--> 168   override_dict = _slice_params_to_dict(
    169       batch_object, slices, bijector_x_event_ndims=bijector_x_event_ndims)
    170 override_dict.update(params_overrides)
    171 parameters = dict(batch_object.parameters, **override_dict)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/, in _slice_params_to_dict(batch_object, slices, bijector_x_event_ndims)
    150 else:
    151   batch_shape = batch_object.experimental_batch_shape_tensor(
    152       x_event_ndims=bijector_x_event_ndims)
--> 153 return batch_shape_lib.map_fn_over_parameters_with_event_ndims(
    154     batch_object,
    155     functools.partial(_slice_single_param,
    156                       slices=slices,
    157                       batch_shape=batch_shape),
    158     bijector_x_event_ndims=bijector_x_event_ndims)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/, in map_fn_over_parameters_with_event_ndims(batch_object, fn, bijector_x_event_ndims, require_static, **parameter_kwargs)
    361     elif (properties.is_tensor
    362           and not tf.is_tensor(param)
    363           and not tf.nest.is_nested(param_event_ndims)):
    364       # As a last resort, try an explicit conversion.
    365       param = tensor_util.convert_nonref_to_tensor(param, name=param_name)
--> 367   results[param_name] = nest.map_structure_up_to(
    368       param, fn, param, param_event_ndims)
    369 return results

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/, in map_structure_up_to(shallow_tree, func, *inputs, **kwargs)
   1361 @tf_export("__internal__.nest.map_structure_up_to", v1=[])
   1362 def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
   1363   """Applies a function or op to a number of partially flattened inputs.
   1365   The `inputs` are flattened up to `shallow_tree` before being mapped.
   1433     `shallow_tree`.
   1434   """
-> 1435   return map_structure_with_tuple_paths_up_to(
   1436       shallow_tree,
   1437       lambda _, *values: func(*values),  # Discards the path arg.
   1438       *inputs,
   1439       **kwargs)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/, in map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs)
   1526 flat_value_gen = (
   1527     flatten_up_to(  # pylint: disable=g-complex-comprehension
   1528         shallow_tree,
   1529         input_tree,
   1530         check_types,
   1531         expand_composites=expand_composites) for input_tree in inputs)
   1532 flat_path_gen = (
   1533     path
   1534     for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_nested_fn))
-> 1535 results = [
   1536     func(*args, **kwargs) for args in zip(flat_path_gen, *flat_value_gen)
   1537 ]
   1538 return pack_sequence_as(structure=shallow_tree, flat_sequence=results,
   1539                         expand_composites=expand_composites)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/, in <listcomp>(.0)
   1526 flat_value_gen = (
   1527     flatten_up_to(  # pylint: disable=g-complex-comprehension
   1528         shallow_tree,
   1529         input_tree,
   1530         check_types,
   1531         expand_composites=expand_composites) for input_tree in inputs)
   1532 flat_path_gen = (
   1533     path
   1534     for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_nested_fn))
   1535 results = [
-> 1536     func(*args, **kwargs) for args in zip(flat_path_gen, *flat_value_gen)
   1537 ]
   1538 return pack_sequence_as(structure=shallow_tree, flat_sequence=results,
   1539                         expand_composites=expand_composites)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow/python/util/, in map_structure_up_to.<locals>.<lambda>(_, *values)
   1361 @tf_export("__internal__.nest.map_structure_up_to", v1=[])
   1362 def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
   1363   """Applies a function or op to a number of partially flattened inputs.
   1365   The `inputs` are flattened up to `shallow_tree` before being mapped.
   1433     `shallow_tree`.
   1434   """
   1435   return map_structure_with_tuple_paths_up_to(
   1436       shallow_tree,
-> 1437       lambda _, *values: func(*values),  # Discards the path arg.
   1438       *inputs,
   1439       **kwargs)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/, in _slice_single_param(param, param_event_ndims, slices, batch_shape)
     85 """Slices into the batch shape of a single parameter.
     87 Args:
     98     `slices`.
     99 """
    100 # Broadcast the parmameter to have full batch rank.
--> 101 param = batch_shape_lib.broadcast_parameter_with_batch_shape(
    102     param, param_event_ndims, ps.ones_like(batch_shape))
    103 param_batch_shape = batch_shape_lib.get_batch_shape_tensor_part(
    104     param, param_event_ndims)
    105 # At this point the param should have full batch rank, *unless* it's an
    106 # atomic object like `tfb.Identity()` incapable of having any batch rank.

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/, in broadcast_parameter_with_batch_shape(param, param_event_ndims, batch_shape)
    270 base_shape = ps.concat([batch_shape,
    271                         ps.ones([param_event_ndims], dtype=np.int32)],
    272                        axis=0)
    273 if hasattr(param, '_broadcast_parameters_with_batch_shape'):
--> 274   return param._broadcast_parameters_with_batch_shape(base_shape)  # pylint: disable=protected-access
    275 elif hasattr(param, 'matmul'):
    276   # TODO(davmre): support broadcasting LinearOperator parameters.
    277   return param

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/distributions/, in Distribution._broadcast_parameters_with_batch_shape(self, batch_shape)
    926 def _broadcast_parameters_with_batch_shape(self, batch_shape):
    927   """Broadcasts each parameter's batch shape with the given `batch_shape`.
    929   This is semantically equivalent to wrapping with the `BatchBroadcast`
    950       the given `batch_shape`.
    951   """
--> 952   return self.copy(
    953       **batch_shape_lib.broadcast_parameters_with_batch_shape(
    954           self, batch_shape))

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/distributions/, in Distribution.copy(self, **override_parameters_kwargs)
    897 """Creates a deep copy of the distribution.
    899 Note: the copy distribution may continue to depend on the original
    909     `dict(self.parameters, **override_parameters_kwargs)`.
    910 """
    911 try:
    912   # We want track provenance from origin variables, so we use batch_slice
    913   # if this distribution supports slicing. See the comment on
    914   # PROVENANCE_ATTR in
--> 915   return slicing.batch_slice(self, override_parameters_kwargs, Ellipsis)
    916 except NotImplementedError:
    917   pass

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/, in batch_slice(batch_object, params_overrides, slices, bijector_x_event_ndims)
    217 slice_overrides_seq = slice_overrides_seq + [(slices, params_overrides)]
    218 # Re-doing the full sequence of slice+copy override work here enables
    219 # gradients all the way back to the original batch_objectribution's arguments.
--> 220 batch_object = _apply_slice_sequence(
    221     orig_batch_object,
    222     slice_overrides_seq,
    223     bijector_x_event_ndims=bijector_x_event_ndims)
    224 setattr(batch_object,
    225         PROVENANCE_ATTR,
    226         batch_object._no_dependency((orig_batch_object, slice_overrides_seq)))  # pylint: disable=protected-access
    227 return batch_object

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/, in _apply_slice_sequence(batch_object, slice_overrides_seq, bijector_x_event_ndims)
    177 """Applies a sequence of slice or copy-with-overrides operations to `batch_object`."""
    178 for slices, overrides in slice_overrides_seq:
--> 179   batch_object = _apply_single_step(
    180       batch_object,
    181       slices,
    182       overrides,
    183       bijector_x_event_ndims=bijector_x_event_ndims)
    184 return batch_object

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/tensorflow_probability/python/internal/, in _apply_single_step(batch_object, slices, params_overrides, bijector_x_event_ndims)
    170 override_dict.update(params_overrides)
    171 parameters = dict(batch_object.parameters, **override_dict)
--> 172 return type(batch_object)(**parameters)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/, in decorate.<locals>.fun(*args, **kw)
    229 def fun(*args, **kw):
    230     if not kwsyntax:
--> 231         args, kw = fix(args, kw, sig)
    232     return caller(func, *(extras + args), **kw)

File ~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/, in fix(args, kwargs, sig)
    199 def fix(args, kwargs, sig):
    200     """
    201     Fix args and kwargs to be consistent with the signature
    202     """
--> 203     ba = sig.bind(*args, **kwargs)
    204     ba.apply_defaults()  # needed for test_dan_schult
    205     return ba.args, ba.kwargs

File ~/mambaforge/envs/phenotypes/lib/python3.10/, in Signature.bind(self, *args, **kwargs)
   3174 def bind(self, /, *args, **kwargs):
   3175     """Get a BoundArguments object, that maps the passed `args`
   3176     and `kwargs` to the function's signature.  Raises `TypeError`
   3177     if the passed arguments can not be bound.
   3178     """
-> 3179     return self._bind(args, kwargs)

File ~/mambaforge/envs/phenotypes/lib/python3.10/, in Signature._bind(self, args, kwargs, partial)
   3166         arguments[] = kwargs
   3167     else:
-> 3168         raise TypeError(
   3169             'got an unexpected keyword argument {arg!r}'.format(
   3170                 arg=next(iter(kwargs))))
   3172 return self._bound_arguments_cls(self, arguments)

TypeError: got an unexpected keyword argument 'reinterpreted_batch_ndims'

Versions: TensorFlow 2.11.0 (v2.11.0-rc2-17-gd5b57ca93e5), tensorflow-probability 0.19.0, Python 3.10.6, Ubuntu 18.04

Any ideas why this is happening and how to fix?
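A likely explanation, pieced together from the traceback rather than confirmed by the TFP maintainers: the last frames show that slicing rebuilds each sub-distribution with `type(batch_object)(**parameters)`, where `parameters` are that object's original constructor arguments (here including `reinterpreted_batch_ndims`, presumably from an Independent distribution inside MixtureNormal). The Keras output also prints as `_MixtureSameFamily` rather than `MixtureSameFamily`, which suggests the layer returns a wrapper subclass. If such a subclass's constructor does not accept the base class's keyword arguments, the rebuild fails with exactly this TypeError. The pure-Python sketch reproduces that failure mode with hypothetical classes (`Independent`, `_Wrapped`); it illustrates the suspected mechanism and is not TFP's actual code.

```python
class Independent:
    """Hypothetical stand-in for a distribution that records its constructor arguments."""

    def __init__(self, distribution, reinterpreted_batch_ndims=0):
        self.parameters = dict(
            distribution=distribution,
            reinterpreted_batch_ndims=reinterpreted_batch_ndims,
        )


class _Wrapped(Independent):
    """Hypothetical wrapper subclass with a different constructor signature."""

    def __init__(self, distribution, convert_to_tensor_fn=None):
        # Expose the wrapped object's constructor arguments as if they were our own.
        self.parameters = dict(distribution.parameters)
        self.convert_to_tensor_fn = convert_to_tensor_fn


def copy_like_slicing(obj):
    """Rebuild obj the way the traceback does: type(obj)(**parameters)."""
    return type(obj)(**obj.parameters)


plain = Independent(distribution="normal", reinterpreted_batch_ndims=0)
wrapped = _Wrapped(plain)

# Rebuilding the plain object works because its signature matches its parameters.
rebuilt = copy_like_slicing(plain)

# Rebuilding the wrapper fails: its __init__ has no reinterpreted_batch_ndims argument.
try:
    copy_like_slicing(wrapped)
    error_message = None
except TypeError as err:
    error_message = str(err)
print(error_message)
```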

tilakrayal commented 1 year ago

@henrypinkard, the reinterpreted_batch_ndims parameter (of tfd.Independent) controls how many batch dims are absorbed as event dims, with reinterpreted_batch_ndims <= len(batch_shape). For example, the _log_prob method performs a reduce_sum over the rightmost reinterpreted_batch_ndims dimensions after calling the base distribution's log_prob. Because the batch dimension(s) index independent distributions, the resulting multivariate distribution has independent components. Since reinterpreted_batch_ndims is part of TensorFlow Probability (TFP), please check that repo for more assistance. Thank you!

henrypinkard commented 1 year ago

Thanks for the explanation. I also opened an issue on TFP (…), though there has been no response yet.

tilakrayal commented 1 year ago

@henrypinkard, could you please close this issue, since it is already being tracked there? Thank you!

henrypinkard commented 1 year ago

Still no activity on the TFP issue, and no clarity on how this might be fixed.